
Commit: pack&isp

Sullivan-24 committed Aug 14, 2024
1 parent 87ef82e commit 9084462
Showing 2 changed files with 37 additions and 32 deletions.
8 changes: 0 additions & 8 deletions examples/meta_llama/Meta_Llama_3_1_8B/train.py
@@ -26,14 +26,6 @@ def main(args):
# register huggingface model and config for InternEvo
model_initializer.register_module(gpc.config.model_type, LlamaForCausalLM)
hf_config_initializer.register_module(gpc.config.model_type, LlamaConfig)
# if gpc.config.model_type == "HF":
# hf_config_builder = hf_config_initializer.get_module(module_name=gpc.config.model_type)
# hf_cfg = hf_config_builder(return_dict=False)
# gpc.config.model.num_layers = hf_cfg.num_hidden_layers
# gpc.config.model.hidden_size = hf_cfg.hidden_size
# gpc.config.model.num_attention_heads = hf_cfg.num_attention_heads
# gpc.config.model.mlp_ratio = hf_cfg.intermediate_size / hf_cfg.hidden_size
# gpc.config.model.vocab_size = hf_cfg.vocab_size

# initialize model
model = initialize_model()
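# Hedged note (illustrative, not in the original file): with LlamaForCausalLM registered
# above under gpc.config.model_type, initialize_model() is assumed to build that
# registered class from gpc.config.model and wrap it for InternEvo's parallel training.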
61 changes: 37 additions & 24 deletions huggingface_model/meta_llama/Meta_Llama_3_1_8B/modeling_llama.py
@@ -23,9 +23,20 @@
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import (
IS_REPLICA_ZERO_PARALLEL,
IS_WEIGHT_ZERO_PARALLEL,
)
from internlm.model.modules.embedding import Embedding1D
from internlm.model.modules.linear import new_linear
from internlm.model.ops.attention import (
hf_q_k_v_with_cu_seqlens,
hf_q_k_v_without_cu_seqlens,
)
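# Hedged note (illustrative, not in the original file): these InternEvo imports are
# presumably what lets the HF Llama modules join InternEvo parallelism -- Embedding1D
# is a parallel token embedding, new_linear builds parallel-aware projection layers,
# the hf_q_k_v_* ops wrap attention kernels with and without cu_seqlens (packed data),
# and the IS_*_ZERO_PARALLEL flags tag parameters for ZeRO/weight-parallel grouping.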
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -46,13 +57,8 @@
logging,
replace_return_docstrings,
)

from .configuration_llama import LlamaConfig
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.modules.embedding import Embedding1D
from internlm.model.modules.linear import new_linear
from internlm.model.ops.attention import hf_q_k_v_with_cu_seqlens, hf_q_k_v_without_cu_seqlens
from internlm.core.context.parallel_context import IS_REPLICA_ZERO_PARALLEL, IS_WEIGHT_ZERO_PARALLEL

logger = logging.get_logger(__name__)

@@ -105,9 +111,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)

return causal_mask

@@ -303,9 +307,7 @@ def forward(self, x):
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)

gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
@@ -362,8 +364,12 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
)

self.q_proj = new_linear("wq", self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = new_linear("wk", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = new_linear("wv", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = new_linear(
"wk", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = new_linear(
"wv", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = new_linear("wo", self.hidden_size, self.hidden_size, bias=config.attention_bias)
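# Hedged note (illustrative, not in the original file): new_linear is assumed to be
# InternEvo's factory for parallel-aware linear layers, so swapping it in for nn.Linear
# lets the q/k/v/o projections be sharded according to the configured tensor/weight
# parallel mode while keeping the original shapes and attention_bias behaviour.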

# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
@@ -500,7 +506,7 @@ def forward(
bsz, q_len, _ = hidden_states.size()

use_packed_dataset = gpc.config.data.get("use_packed_dataset", False)

if use_packed_dataset:
assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True"
cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
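# Illustrative sketch (not in the original file): with use_packed_dataset=True, several
# variable-length samples are concatenated into one bsz=1 row and cu_seqlens records the
# cumulative boundaries. For example, packing sequences of lengths 3, 5 and 2 gives
#   cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)
#   max_seqlen = 5  # longest individual sequence in the pack
# so the variable-length attention kernel can keep tokens from different samples from
# attending to each other.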
@@ -516,7 +522,7 @@ def forward(
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
#import pdb;pdb.set_trace()
# import pdb;pdb.set_trace()
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -528,8 +534,8 @@ def forward(
else:
cos, sin = position_embeddings

cos_temp = cos.transpose(0,1).unsqueeze(0)
sin_temp = sin.transpose(0,1).unsqueeze(0)
cos_temp = cos.transpose(0, 1).unsqueeze(0)
sin_temp = sin.transpose(0, 1).unsqueeze(0)
query_states = (query_states * cos_temp) + (rotate_half(query_states) * sin_temp)
key_states = (key_states * cos_temp) + (rotate_half(key_states) * sin_temp)
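# Illustrative note (not in the original file): after transposing and unsqueezing cos/sin
# for broadcasting, the two lines above apply the standard RoPE identity
#   q' = q * cos + rotate_half(q) * sin   (and likewise for k),
# i.e. the same rotation apply_rotary_pos_emb computes, just with explicitly re-laid-out
# cos/sin tensors.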

@@ -571,7 +577,7 @@ def forward(
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

if use_packed_dataset:
attn_output = hf_q_k_v_with_cu_seqlens(
query_states,
@@ -583,9 +589,14 @@
)
else:
attn_output = hf_q_k_v_without_cu_seqlens(
query_states, key_states, value_states, dropout_p=dropout_rate, softmax_scale=None, causal=True,
query_states,
key_states,
value_states,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
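# Hedged note (illustrative, not in the original file): hf_q_k_v_with_cu_seqlens /
# hf_q_k_v_without_cu_seqlens are presumed thin wrappers over InternEvo's attention
# kernels: the cu_seqlens variant restricts causal attention to within each packed
# sample of the bsz=1 layout, while the plain variant runs ordinary causal attention
# per batch row. Either way, the result is reshaped above to (bsz, q_len, hidden_size)
# before the o_proj output projection.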

@@ -641,7 +652,9 @@ def forward(
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
import pdb;pdb.set_trace()
import pdb

pdb.set_trace()
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -652,7 +665,7 @@
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
