diff --git a/examples/Qwen/Qwen2_7B/config.py b/examples/Qwen/Qwen2_7B/config.py
index 29b9981..ae89b7f 100644
--- a/examples/Qwen/Qwen2_7B/config.py
+++ b/examples/Qwen/Qwen2_7B/config.py
@@ -15,7 +15,7 @@
 # import os
 # BOTO3_IP = os.environ["BOTO3_IP"]  # boto3 bucket endpoint
 # SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
-CHECKPOINT_EVERY = 10000
+CHECKPOINT_EVERY = 100000
 ckpt = dict(
     enable_save_ckpt=False,  # enable ckpt save.
     save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
@@ -68,7 +68,7 @@
     valid_folder=VALID_FOLDER,
     empty_cache_and_diag_interval=200,
     diag_outlier_ratio=1.1,
-    use_packed_dataset=False,
+    use_packed_dataset=True,
     # whether use shared memory to load meta files
     use_shm=False,
     # when use shm, the default shm_path is "/dev/shm/metacache"
@@ -146,7 +146,7 @@
     dtype="torch.bfloat16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
-    use_flash_attn=False,
+    use_flash_attn=True,
     # Whether the odd and even columns of the query and key in the model are normally interleaved.
     # If it's True, the model's odd and even columns are normally ordered; if it's False,
     # it means that the model has prematurely concatenated all odd columns and even columns in front
@@ -185,7 +185,7 @@
 """
 parallel = dict(
     zero1=dict(size=-1),
-    tensor=dict(size=1, mode="mtp"),
+    tensor=dict(size=1, mode="isp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     weight=dict(size=1, overlap=False, memory_pool=True),
 )
diff --git a/examples/Qwen/Qwen2_7B/train.py b/examples/Qwen/Qwen2_7B/train.py
index 20e4c61..e59b9c0 100644
--- a/examples/Qwen/Qwen2_7B/train.py
+++ b/examples/Qwen/Qwen2_7B/train.py
@@ -22,14 +22,6 @@ def main(args):
     # register huggingface model and config for InternEvo
     model_initializer.register_module(gpc.config.model_type, Qwen2ForCausalLM)
     hf_config_initializer.register_module(gpc.config.model_type, Qwen2Config)
-    if gpc.config.model_type == "hf":
-        hf_config_builder = hf_config_initializer.get_module(module_name=gpc.config.model_type)
-        hf_cfg = hf_config_builder(return_dict=False)
-        gpc.config.model.num_layers = hf_cfg.num_hidden_layers
-        gpc.config.model.hidden_size = hf_cfg.hidden_size
-        gpc.config.model.num_attention_heads = hf_cfg.num_attention_heads
-        gpc.config.model.mlp_ratio = hf_cfg.intermediate_size / hf_cfg.hidden_size
-        gpc.config.model.vocab_size = hf_cfg.vocab_size

     # initialize model
     model = initialize_model()
diff --git a/huggingface_model/Qwen/Qwen2_7B/modeling_qwen2.py b/huggingface_model/Qwen/Qwen2_7B/modeling_qwen2.py
index e4871c4..8a2eb89 100644
--- a/huggingface_model/Qwen/Qwen2_7B/modeling_qwen2.py
+++ b/huggingface_model/Qwen/Qwen2_7B/modeling_qwen2.py
@@ -18,20 +18,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
"""PyTorch Qwen2 model.""" - import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import IS_REPLICA_ZERO_PARALLEL + +# from internlm.model.ops.rotary_emb import apply_rotary_emb +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import new_linear +from internlm.model.ops.attention import ( + isp_flash_attn_varlen_func, + isp_flash_attn_func, +) from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache -from transformers.modeling_attn_mask_utils import ( - AttentionMaskConverter, -) +from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -47,8 +54,8 @@ logging, replace_return_docstrings, ) -from .configuration_qwen2 import Qwen2Config +from .configuration_qwen2 import Qwen2Config if is_flash_attn_2_available(): from transformers.modeling_flash_attention_utils import _flash_attention_forward @@ -70,6 +77,8 @@ def __init__(self, hidden_size, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + for param in self.parameters(): + setattr(param, IS_REPLICA_ZERO_PARALLEL, True) def forward(self, hidden_states): input_dtype = hidden_states.dtype @@ -159,9 +168,13 @@ def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + # self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + # self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + # self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.gate_proj = new_linear("w1", self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = new_linear("w2", self.intermediate_size, self.hidden_size, bias=False) + self.up_proj = new_linear("w3", self.hidden_size, self.intermediate_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_state): @@ -213,10 +226,14 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." 
             )
-        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        # self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+        # self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        # self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        # self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self.q_proj = new_linear("wq", self.hidden_size, self.num_heads * self.head_dim, bias=True)
+        self.k_proj = new_linear("wk", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.v_proj = new_linear("wv", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.o_proj = new_linear("wo", self.num_heads * self.head_dim, self.hidden_size, bias=False)

         self.rotary_emb = Qwen2RotaryEmbedding(
             self.head_dim,
@@ -328,6 +345,8 @@ def forward(
     ):
         bsz, q_len, _ = hidden_states.size()

+        use_packed_dataset = gpc.config.data.get("use_packed_dataset", False)
+
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
@@ -340,17 +359,33 @@
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "The cache structure has changed since version v4.36. "
+                    f"If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, "
+                    "please make sure to initialize the attention class "
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

-        # Because the input can be padded, the absolute sequence length depends on the max position id.
-        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
-        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        if use_packed_dataset:
+            assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True"
+            cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
+            max_seqlen = gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
+
+            assert position_ids is not None
+            rotary_seq_len = max(kv_seq_len, position_ids.max().item()) + 1
+            cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+            # query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+            # cos, sin, position_ids)
+            cos_temp = cos[position_ids].unsqueeze(0).unsqueeze(0)
+            sin_temp = sin[position_ids].unsqueeze(0).unsqueeze(0)
+            query_states = (query_states * cos_temp) + (rotate_half(query_states) * sin_temp)
+            key_states = (key_states * cos_temp) + (rotate_half(key_states) * sin_temp)
+        else:
+            # Because the input can be padded, the absolute sequence length depends on the max position id.
+            rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+            cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
@@ -414,27 +449,29 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)

-        if (
-            self.config.use_sliding_window
-            and getattr(self.config, "sliding_window", None) is not None
-            and self.layer_idx >= self.config.max_window_layers
-        ):
-            sliding_window = self.config.sliding_window
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
         else:
-            sliding_window = None
-
-        attn_output = _flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            q_len,
-            position_ids=position_ids,
-            dropout=dropout_rate,
-            sliding_window=sliding_window,
-            is_causal=self.is_causal,
-            use_top_left_mask=self._flash_attn_uses_top_left_mask,
-        )
+            causal = self.is_causal and q_len != 1
+
+        if use_packed_dataset:
+            attn_output = isp_flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+                causal=causal,
+                attention_dropout=dropout_rate,
+            )
+        else:
+            attn_output = isp_flash_attn_func(
+                query_states,
+                key_states,
+                value_states,
+                causal=causal,
+                attention_dropout=dropout_rate,
+            )

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.o_proj(attn_output)
@@ -659,7 +696,7 @@ def _init_weights(self, module):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
                 module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
+        elif isinstance(module, Embedding1D):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
@@ -756,7 +793,11 @@ def __init__(self, config: Qwen2Config):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embed_tokens = Embedding1D(
+            num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, padding_idx=self.padding_idx
+        )
+
         self.layers = nn.ModuleList(
             [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
@@ -944,9 +985,7 @@ def _update_causal_mask(
                 raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
             causal_mask = attention_mask
         else:
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
-            )
+            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
             causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
@@ -980,7 +1019,8 @@ def __init__(self, config):
         super().__init__(config)
         self.model = Qwen2Model(config)
         self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.lm_head = new_linear("head", config.hidden_size, config.vocab_size, bias=False)

         # Initialize weights and apply final processing
         self.post_init()
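
The attention changes above apply rotary embeddings per token when samples are packed into a single batch row (bsz=1): instead of calling apply_rotary_pos_emb, the patch gathers cos/sin at the restarted position_ids and combines them with rotate_half. The standalone sketch below illustrates that technique with plain PyTorch; every name in it (seq_lens, position_ids, q, k, and so on) is local to the example and only mirrors the shapes used in the patch, so treat it as an assumption-labelled illustration rather than InternEvo code.

# Sketch: per-token rotary embedding over packed sequences (illustrative only).
import torch

def rotate_half(x):
    # Same convention as modeling_qwen2.py: rotate half the hidden dims.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

num_heads, head_dim = 2, 8
seq_lens = [3, 5]  # two samples packed back to back into one batch row
max_seqlen = max(seq_lens)

# position ids restart at 0 for every packed sample: [0, 1, 2, 0, 1, 2, 3, 4]
position_ids = torch.cat([torch.arange(n) for n in seq_lens])
total_tokens = position_ids.numel()

# cos/sin lookup tables of shape (max_position, head_dim)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_seqlen).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

# packed q/k with batch size 1: (1, num_heads, total_tokens, head_dim)
q = torch.randn(1, num_heads, total_tokens, head_dim)
k = torch.randn(1, num_heads, total_tokens, head_dim)

# gather per-token cos/sin and broadcast over the batch and head dimensions
cos_t = cos[position_ids][None, None]  # (1, 1, total_tokens, head_dim)
sin_t = sin[position_ids][None, None]
q_rot = q * cos_t + rotate_half(q) * sin_t
k_rot = k * cos_t + rotate_half(k) * sin_t
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 8, 8]) twice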
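
For reference, assuming isp_flash_attn_varlen_func follows the usual varlen-attention convention implied by the call above, causal attention driven by cu_seqlens over a packed row is equivalent to ordinary attention under a block-diagonal causal mask. The sketch below builds that mask explicitly and runs scaled_dot_product_attention as a slow reference; it is an assumption-labelled illustration, not the kernel the patch dispatches to.

# Sketch: block-diagonal causal mask equivalent of a cu_seqlens-driven varlen
# attention call (illustrative only; slow reference for debugging packed batches).
import torch

def block_diag_causal_mask(cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Boolean (total, total) mask: True where a query may attend to a key."""
    total = int(cu_seqlens[-1])
    # sample index of every packed token, e.g. [0, 0, 0, 1, 1, 1, 1, 1]
    sample_id = torch.bucketize(torch.arange(total), cu_seqlens[1:], right=True)
    same_sample = sample_id[:, None] == sample_id[None, :]
    causal = torch.arange(total)[:, None] >= torch.arange(total)[None, :]
    return same_sample & causal

cu_seqlens = torch.tensor([0, 3, 8])  # two packed samples of length 3 and 5
mask = block_diag_causal_mask(cu_seqlens)

num_heads, head_dim = 2, 8
q = torch.randn(1, num_heads, 8, head_dim)
k = torch.randn(1, num_heads, 8, head_dim)
v = torch.randn(1, num_heads, 8, head_dim)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 2, 8, 8])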