packAndisp
Sullivan-24 committed Aug 9, 2024
1 parent f2c855e commit 87ef82e
Showing 2 changed files with 14 additions and 15 deletions.
4 changes: 2 additions & 2 deletions examples/meta_llama/Meta_Llama_3_1_8B/config.py
@@ -185,9 +185,9 @@
"""
parallel = dict(
zero1=dict(size=-1),
- tensor=dict(size=4, mode="isp"),
+ tensor=dict(size=2, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
- weight=dict(size=4, overlap=False, memory_pool=True),
+ weight=dict(size=2, overlap=False, memory_pool=True),
)

cudnn_deterministic = False
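For reference, a minimal sketch of how the parallel block reads after this change, using only the fields visible in the hunk above (other config sections are untouched). The ISP tensor-parallel size and the weight-parallel size are lowered from 4 to 2 together:

parallel = dict(
    zero1=dict(size=-1),
    tensor=dict(size=2, mode="isp"),
    pipeline=dict(size=1, interleaved_overlap=True),
    weight=dict(size=2, overlap=False, memory_pool=True),
)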
25 changes: 12 additions & 13 deletions huggingface_model/meta_llama/Meta_Llama_3_1_8B/modeling_llama.py
@@ -29,7 +29,6 @@
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
- #from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -50,10 +49,6 @@
from .configuration_llama import LlamaConfig
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
- # try:
- # from flash_attn import flash_attn_func, flash_attn_varlen_func
- # except:
- # pass
from internlm.model.modules.embedding import Embedding1D
from internlm.model.modules.linear import new_linear
from internlm.model.ops.attention import hf_q_k_v_with_cu_seqlens, hf_q_k_v_without_cu_seqlens
@@ -214,7 +209,7 @@ def forward(self, x, position_ids):

# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids.unsqueeze(0)[:, None, :].float()
+ position_ids_expanded = position_ids.unsqueeze(1)[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
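A rough shape walk-through of why `unsqueeze(1)` fits here, under the assumption that the packed path passes 1-D position_ids (consistent with the commented-out `position_ids = position_ids.unsqueeze(0)` removed further down); the last three lines reproduce the rest of the stock Hugging Face rotary forward, which this hunk does not show:

import torch

q_len, half_dim = 16, 32                         # illustrative sizes
inv_freq = torch.rand(half_dim)
position_ids = torch.arange(q_len)               # assumed 1-D in the packed path

# Lines from the hunk above:
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)  # (q_len, half_dim, 1)
position_ids_expanded = position_ids.unsqueeze(1)[:, None, :].float()                     # (q_len, 1, 1)

# Remainder of the stock HF rotary forward (not part of this diff):
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)                       # (q_len, 1, half_dim)
emb = torch.cat((freqs, freqs), dim=-1)
print(emb.cos().shape)                           # torch.Size([16, 1, 64])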
@@ -510,7 +505,6 @@ def forward(
assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True"
cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
max_seqlen = gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
- #position_ids = position_ids.unsqueeze(0)

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@@ -522,7 +516,7 @@
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
+ #import pdb;pdb.set_trace()
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -533,7 +527,11 @@
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ cos_temp = cos.transpose(0,1).unsqueeze(0)
+ sin_temp = sin.transpose(0,1).unsqueeze(0)
+ query_states = (query_states * cos_temp) + (rotate_half(query_states) * sin_temp)
+ key_states = (key_states * cos_temp) + (rotate_half(key_states) * sin_temp)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
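The inline replacement for `apply_rotary_pos_emb` above performs the same `q * cos + rotate_half(q) * sin` computation, but first moves cos/sin into a layout that broadcasts against `(bsz, num_heads, q_len, head_dim)` states. A minimal sketch, assuming cos/sin arrive as `(q_len, 1, head_dim)` (as in the rotary-embedding sketch earlier); `rotate_half` is the standard Hugging Face helper:

import torch

def rotate_half(x):
    # Standard helper: rotates half of the last dimension.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bsz, num_heads, q_len, head_dim = 1, 8, 16, 64   # illustrative sizes
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
cos = torch.randn(q_len, 1, head_dim)            # assumed shape from the rotary module
sin = torch.randn(q_len, 1, head_dim)

# As in the diff: (q_len, 1, head_dim) -> (1, 1, q_len, head_dim),
# which broadcasts over the batch and head dimensions.
cos_temp = cos.transpose(0, 1).unsqueeze(0)
sin_temp = sin.transpose(0, 1).unsqueeze(0)
query_states = (query_states * cos_temp) + (rotate_half(query_states) * sin_temp)
print(query_states.shape)                        # torch.Size([1, 8, 16, 64])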
@@ -573,7 +571,7 @@
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

if use_packed_dataset:
attn_output = hf_q_k_v_with_cu_seqlens(
query_states,
@@ -582,12 +580,12 @@
cumulative_len=cu_seqlens,
max_seqlen=max_seqlen,
dropout_p=dropout_rate,
- ).unsqueeze(0)
+ )
else:
attn_output = hf_q_k_v_without_cu_seqlens(
query_states, key_states, value_states, dropout_p=dropout_rate, softmax_scale=None, causal=True,
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
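Dropping the `.unsqueeze(0)` on the packed-attention output is consistent with the reshape just above: with `bsz == 1`, a batchless `(q_len, num_heads, head_dim)` result and its unsqueezed `(1, q_len, num_heads, head_dim)` counterpart flatten to the same `(1, q_len, hidden)` tensor. A small sketch (the output shape of `hf_q_k_v_with_cu_seqlens` is an assumption here):

import torch

bsz, q_len, num_heads, head_dim = 1, 16, 8, 64
attn_out = torch.randn(q_len, num_heads, head_dim)           # assumed varlen attention output, no batch dim

with_batch = attn_out.unsqueeze(0).reshape(bsz, q_len, -1)   # old path: add batch dim, then flatten
without_batch = attn_out.reshape(bsz, q_len, -1)             # new path: flatten directly
assert torch.equal(with_batch, without_batch)                # identical (1, 16, 512) tensors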

@@ -643,7 +641,7 @@ def forward(
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
+ import pdb;pdb.set_trace()
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -654,6 +652,7 @@
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
