Qwen2 using packed dataset (#14)
* add initial README for packed training and ISP adaptation

* Qwen2 using packed dataset

* mv README.md to huggingface_model folder

* make ISP-aligned training runnable

---------

Co-authored-by: zigzagcai <[email protected]>
guojihu-hana and zigzagcai authored Sep 5, 2024
1 parent 4eb14c5 commit 37177a5
Showing 3 changed files with 90 additions and 58 deletions.
8 changes: 4 additions & 4 deletions examples/Qwen/Qwen2_7B/config.py
@@ -15,7 +15,7 @@
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
-CHECKPOINT_EVERY = 10000
+CHECKPOINT_EVERY = 100000
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
@@ -68,7 +68,7 @@
valid_folder=VALID_FOLDER,
empty_cache_and_diag_interval=200,
diag_outlier_ratio=1.1,
-    use_packed_dataset=False,
+    use_packed_dataset=True,
# whether use shared memory to load meta files
use_shm=False,
# when use shm, the default shm_path is "/dev/shm/metacache"
@@ -146,7 +146,7 @@
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
-    use_flash_attn=False,
+    use_flash_attn=True,
# Whether the odd and even columns of the query and key in the model are normally interleaved.
# If it's True, the model's odd and even columns are normally ordered; if it's False,
# it means that the model has prematurely concatenated all odd columns and even columns in front
@@ -185,7 +185,7 @@
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
tensor=dict(size=1, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=False, memory_pool=True),
)
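Taken together, the config.py edits switch the Qwen2-7B example to a packed dataset, flash attention, and the ISP tensor-parallel mode. A minimal sketch of just the fields this commit touches, assuming the same InternEvo dict-style config as the file above (all surrounding fields left unchanged):

```python
CHECKPOINT_EVERY = 100000  # raised from 10000 in this commit

data = dict(
    use_packed_dataset=True,   # several samples are packed into a single bsz=1 row per step
    # ... remaining data fields as before ...
)

model = dict(
    use_flash_attn=True,       # presumably needed so the flash-attention path in modeling_qwen2.py is taken
    # ... remaining model fields as before ...
)

parallel = dict(
    zero1=dict(size=-1),
    tensor=dict(size=1, mode="isp"),   # "isp" replaces the previous "mtp" tensor-parallel mode
    pipeline=dict(size=1, interleaved_overlap=True),
    weight=dict(size=1, overlap=False, memory_pool=True),
)
```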
8 changes: 0 additions & 8 deletions examples/Qwen/Qwen2_7B/train.py
@@ -22,14 +22,6 @@ def main(args):
# register huggingface model and config for InternEvo
model_initializer.register_module(gpc.config.model_type, Qwen2ForCausalLM)
hf_config_initializer.register_module(gpc.config.model_type, Qwen2Config)
if gpc.config.model_type == "hf":
hf_config_builder = hf_config_initializer.get_module(module_name=gpc.config.model_type)
hf_cfg = hf_config_builder(return_dict=False)
gpc.config.model.num_layers = hf_cfg.num_hidden_layers
gpc.config.model.hidden_size = hf_cfg.hidden_size
gpc.config.model.num_attention_heads = hf_cfg.num_attention_heads
gpc.config.model.mlp_ratio = hf_cfg.intermediate_size / hf_cfg.hidden_size
gpc.config.model.vocab_size = hf_cfg.vocab_size

# initialize model
model = initialize_model()
132 changes: 86 additions & 46 deletions huggingface_model/Qwen/Qwen2_7B/modeling_qwen2.py
@@ -18,20 +18,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2 model."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import IS_REPLICA_ZERO_PARALLEL

# from internlm.model.ops.rotary_emb import apply_rotary_emb
from internlm.model.modules.embedding import Embedding1D
from internlm.model.modules.linear import new_linear
from internlm.model.ops.attention import (
isp_flash_attn_varlen_func,
isp_flash_attn_func,
)
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
)
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -47,8 +54,8 @@
logging,
replace_return_docstrings,
)
from .configuration_qwen2 import Qwen2Config

from .configuration_qwen2 import Qwen2Config

if is_flash_attn_2_available():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
Expand All @@ -70,6 +77,8 @@ def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
for param in self.parameters():
setattr(param, IS_REPLICA_ZERO_PARALLEL, True)

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
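The only functional change to Qwen2RMSNorm is the flag set on its parameters. A standalone sketch of that pattern, under the assumption that IS_REPLICA_ZERO_PARALLEL tells InternEvo's ZeRO/weight-parallel bookkeeping to treat the tagged parameters as fully replicated rather than sharded (mark_replica_params is a hypothetical helper, not part of this commit):

```python
from internlm.core.context.parallel_context import IS_REPLICA_ZERO_PARALLEL

def mark_replica_params(module):
    # Tag every parameter of the module (here: the single RMSNorm weight vector)
    # so it is kept in full on every rank instead of being partitioned.
    for param in module.parameters():
        setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
```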
@@ -159,9 +168,13 @@ def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
# self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
# self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
# self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.gate_proj = new_linear("w1", self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = new_linear("w2", self.intermediate_size, self.hidden_size, bias=False)
self.up_proj = new_linear("w3", self.hidden_size, self.intermediate_size, bias=False)

self.act_fn = ACT2FN[config.hidden_act]

def forward(self, hidden_state):
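The diff view truncates Qwen2MLP.forward here. For reference, the forward pass is the standard SwiGLU composition from upstream Qwen2 and is unaffected by the constructor swap; the "w1"/"w2"/"w3" strings passed to new_linear appear to be InternEvo's conventional names for the gate, down, and up projections:

```python
    def forward(self, hidden_state):
        # SwiGLU: activate the gate branch, scale the up branch with it, project back down.
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
```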
@@ -213,10 +226,14 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
# self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
# self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
# self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
# self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.q_proj = new_linear("wq", self.hidden_size, self.num_heads * self.head_dim, bias=True)
self.k_proj = new_linear("wk", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = new_linear("wv", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = new_linear("wo", self.num_heads * self.head_dim, self.hidden_size, bias=False)

self.rotary_emb = Qwen2RotaryEmbedding(
self.head_dim,
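The attention hunk likewise only swaps the projection constructors and breaks off inside the Qwen2RotaryEmbedding(...) call. As a quick sanity check on the grouped-query shapes behind those projections, with illustrative numbers that are not taken from this commit:

```python
# Example GQA bookkeeping for q_proj / k_proj / v_proj / o_proj above (values are examples only).
hidden_size = 3584
num_heads = 28
num_key_value_heads = 4

head_dim = hidden_size // num_heads               # 128
q_out_features = num_heads * head_dim             # 3584 -> q_proj output width
kv_out_features = num_key_value_heads * head_dim  # 512  -> k_proj / v_proj output width
o_in_features = num_heads * head_dim              # 3584 -> o_proj input width
```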
@@ -328,6 +345,8 @@ def forward(
):
bsz, q_len, _ = hidden_states.size()

use_packed_dataset = gpc.config.data.get("use_packed_dataset", False)

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
@@ -340,17 +359,33 @@
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"The cache structure has changed since version v4.36. "
f"If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, "
"please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if use_packed_dataset:
assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True"
cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
max_seqlen = gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]

assert position_ids is not None
rotary_seq_len = max(kv_seq_len, position_ids.max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
# cos, sin, position_ids)
cos_temp = cos[position_ids].unsqueeze(0).unsqueeze(0)
sin_temp = sin[position_ids].unsqueeze(0).unsqueeze(0)
query_states = (query_states * cos_temp) + (rotate_half(query_states) * sin_temp)
key_states = (key_states * cos_temp) + (rotate_half(key_states) * sin_temp)
else:
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
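The packed branch above depends on cu_seqlens and restart-at-zero position_ids that the InternEvo data pipeline stores per data-parallel rank in gpc.config.data. A small illustration of that bookkeeping with made-up lengths:

```python
import torch

# Three sequences of lengths 3, 5, and 2 packed into a single bsz=1 row of 10 tokens.
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)    # cumulative sequence boundaries
max_seqlen = 5                                                 # length of the longest packed sequence
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4, 0, 1]])  # positions restart at each boundary

# Indexing cos[position_ids] / sin[position_ids] picks a rotary phase per token, so every
# packed sequence is rotated as if it began at position 0, which appears to be why the
# packed branch gathers by position_ids instead of calling apply_rotary_pos_emb.
```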
@@ -414,27 +449,29 @@ def forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

if (
self.config.use_sliding_window
and getattr(self.config, "sliding_window", None) is not None
and self.layer_idx >= self.config.max_window_layers
):
sliding_window = self.config.sliding_window
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
sliding_window = None

attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=sliding_window,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
causal = self.is_causal and q_len != 1

if use_packed_dataset:
attn_output = isp_flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
causal=causal,
attention_dropout=dropout_rate,
)
else:
attn_output = isp_flash_attn_func(
query_states,
key_states,
value_states,
causal=causal,
attention_dropout=dropout_rate,
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
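A minimal sketch of the dispatch implemented above, using only the two InternEvo ops imported at the top of this file and the keyword arguments visible in the diff; the helper name run_attention is hypothetical:

```python
def run_attention(q, k, v, causal, dropout_rate, cu_seqlens=None, max_seqlen=None):
    if cu_seqlens is not None:
        # Packed path: one bsz=1 row holds several sequences; cu_seqlens marks the
        # boundaries so attention never crosses from one sequence into the next.
        return isp_flash_attn_varlen_func(
            q, k, v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            causal=causal,
            attention_dropout=dropout_rate,
        )
    # Padded path: ordinary batched attention, one sequence per row.
    return isp_flash_attn_func(q, k, v, causal=causal, attention_dropout=dropout_rate)
```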
@@ -659,7 +696,7 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
elif isinstance(module, Embedding1D):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@@ -756,7 +793,11 @@ def __init__(self, config: Qwen2Config):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
# self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.embed_tokens = Embedding1D(
num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, padding_idx=self.padding_idx
)

self.layers = nn.ModuleList(
[Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
@@ -944,9 +985,7 @@ def _update_causal_mask(
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
@@ -980,7 +1019,8 @@ def __init__(self, config):
super().__init__(config)
self.model = Qwen2Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.lm_head = new_linear("head", config.hidden_size, config.vocab_size, bias=False)

# Initialize weights and apply final processing
self.post_init()
