llama3_1_8b_pack&isp
Sullivan-24 committed Aug 6, 2024
1 parent 8a7fcea commit f2c855e
Showing 3 changed files with 152 additions and 62 deletions.
10 changes: 5 additions & 5 deletions examples/meta_llama/Meta_Llama_3_1_8B/config.py
@@ -4,7 +4,7 @@
JOB_NAME = f"train_{model_type}/{HF_MODEL_NAME}"
DO_ALERT = False

SEQ_LEN = 1024
SEQ_LEN = 2048

MODEL_ONLY_FOLDER = HF_MODEL_NAME
# Ckpt folder format:
@@ -68,7 +68,7 @@
valid_folder=VALID_FOLDER,
empty_cache_and_diag_interval=200,
diag_outlier_ratio=1.1,
use_packed_dataset=False,
use_packed_dataset=True,
# whether use shared memory to load meta files
use_shm=False,
# when use shm, the default shm_path is "/dev/shm/metacache"
@@ -146,7 +146,7 @@
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=False,
use_flash_attn=True,
# Whether the odd and even columns of the query and key in the model are normally interleaved.
# If it's True, the model's odd and even columns are normally ordered; if it's False,
# it means that the model has prematurely concatenated all odd columns and even columns in front
@@ -185,9 +185,9 @@
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
tensor=dict(size=4, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=False, memory_pool=True),
weight=dict(size=4, overlap=False, memory_pool=True),
)

cudnn_deterministic = False
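Taken together, these config changes double the sequence length, turn on packed-dataset training with flash attention, and switch from MTP tensor parallelism to ISP with tensor and weight parallel sizes of 4. A minimal sketch of the resulting settings, assuming InternEvo's example-config layout (only the fields shown in the hunks above appear here):

SEQ_LEN = 2048

data = dict(
    seq_len=SEQ_LEN,
    use_packed_dataset=True,  # pack variable-length samples into one bsz=1 sequence per step
)

model = dict(
    dtype="torch.bfloat16",
    use_flash_attn=True,  # packed sequences rely on varlen flash-attention kernels
)

parallel = dict(
    zero1=dict(size=-1),
    tensor=dict(size=4, mode="isp"),  # ISP sequence/tensor parallelism across 4 ranks
    pipeline=dict(size=1, interleaved_overlap=True),
    weight=dict(size=4, overlap=False, memory_pool=True),  # weight parallel size paired with ISP
)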
16 changes: 8 additions & 8 deletions examples/meta_llama/Meta_Llama_3_1_8B/train.py
@@ -26,14 +26,14 @@ def main(args):
# register huggingface model and config for InternEvo
model_initializer.register_module(gpc.config.model_type, LlamaForCausalLM)
hf_config_initializer.register_module(gpc.config.model_type, LlamaConfig)
if gpc.config.model_type == "hf":
hf_config_builder = hf_config_initializer.get_module(module_name=gpc.config.model_type)
hf_cfg = hf_config_builder(return_dict=False)
gpc.config.model.num_layers = hf_cfg.num_hidden_layers
gpc.config.model.hidden_size = hf_cfg.hidden_size
gpc.config.model.num_attention_heads = hf_cfg.num_attention_heads
gpc.config.model.mlp_ratio = hf_cfg.intermediate_size / hf_cfg.hidden_size
gpc.config.model.vocab_size = hf_cfg.vocab_size
# if gpc.config.model_type == "HF":
# hf_config_builder = hf_config_initializer.get_module(module_name=gpc.config.model_type)
# hf_cfg = hf_config_builder(return_dict=False)
# gpc.config.model.num_layers = hf_cfg.num_hidden_layers
# gpc.config.model.hidden_size = hf_cfg.hidden_size
# gpc.config.model.num_attention_heads = hf_cfg.num_attention_heads
# gpc.config.model.mlp_ratio = hf_cfg.intermediate_size / hf_cfg.hidden_size
# gpc.config.model.vocab_size = hf_cfg.vocab_size

# initialize model
model = initialize_model()
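With the HF-config auto-fill block commented out, the model shape is no longer copied from LlamaConfig at startup and must be present in the InternEvo config instead. A hedged sketch of the manual equivalent, using the published Meta-Llama-3.1-8B dimensions (verify against your checkpoint's config.json before relying on them):

from internlm.core.context import global_context as gpc

# Manual equivalent of the commented-out block; values mirror Meta-Llama-3.1-8B.
gpc.config.model.num_layers = 32
gpc.config.model.hidden_size = 4096
gpc.config.model.num_attention_heads = 32
gpc.config.model.mlp_ratio = 14336 / 4096  # intermediate_size / hidden_size = 3.5
gpc.config.model.vocab_size = 128256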
188 changes: 139 additions & 49 deletions huggingface_model/meta_llama/Meta_Llama_3_1_8B/modeling_llama.py
@@ -29,7 +29,7 @@
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import _flash_attention_forward
#from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -48,13 +48,75 @@
replace_return_docstrings,
)
from .configuration_llama import LlamaConfig

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
# try:
# from flash_attn import flash_attn_func, flash_attn_varlen_func
# except:
# pass
from internlm.model.modules.embedding import Embedding1D
from internlm.model.modules.linear import new_linear
from internlm.model.ops.attention import hf_q_k_v_with_cu_seqlens, hf_q_k_v_without_cu_seqlens
from internlm.core.context.parallel_context import IS_REPLICA_ZERO_PARALLEL, IS_WEIGHT_ZERO_PARALLEL

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"


def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache to account for the 0 padding (the part of the cache that is not filled yet).
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)

return causal_mask
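As a quick illustration of the helper above, the sketch below builds a 4D mask for a toy padded batch; shapes and values are illustrative only:

import torch

batch_size, seq_len = 2, 4
dtype = torch.float32
attention_mask_2d = torch.tensor([[1, 1, 1, 1],
                                  [0, 1, 1, 1]])  # second sample starts with one padding token

mask_4d = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask_2d,
    sequence_length=seq_len,
    target_length=seq_len,
    dtype=dtype,
    device=torch.device("cpu"),
    min_dtype=torch.finfo(dtype).min,
    cache_position=torch.arange(seq_len),
    batch_size=batch_size,
)
# mask_4d has shape (2, 1, 4, 4); positions that may not be attended hold torch.finfo(dtype).min.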


class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
@@ -63,6 +125,10 @@ def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
for param in self.parameters():
setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
for weight in self.weight:
setattr(weight, IS_WEIGHT_ZERO_PARALLEL, True)

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
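The two setattr loops above tag the RMSNorm weight for InternEvo's parallel/ZeRO bookkeeping. A hedged sketch of how such flags are typically consumed, assuming only that the flag names are attribute-name strings as used above:

import torch.nn as nn
from internlm.core.context.parallel_context import IS_REPLICA_ZERO_PARALLEL

def replicated_params(module: nn.Module):
    # Collect parameters previously tagged with setattr(param, IS_REPLICA_ZERO_PARALLEL, True).
    return [p for p in module.parameters() if getattr(p, IS_REPLICA_ZERO_PARALLEL, False)]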
@@ -148,7 +214,7 @@ def forward(self, x, position_ids):

# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
position_ids_expanded = position_ids.unsqueeze(0)[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
@@ -230,9 +296,9 @@ def __init__(self, config):
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.gate_proj = new_linear("w1", self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = new_linear("w2", self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = new_linear("w3", self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]

def forward(self, x):
@@ -300,10 +366,10 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
f" and `num_heads`: {self.num_heads})."
)

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self.q_proj = new_linear("wq", self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = new_linear("wk", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = new_linear("wv", self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = new_linear("wo", self.hidden_size, self.hidden_size, bias=config.attention_bias)

# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
@@ -438,6 +504,14 @@ def forward(

bsz, q_len, _ = hidden_states.size()

use_packed_dataset = gpc.config.data.get("use_packed_dataset", False)

if use_packed_dataset:
assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True"
cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
max_seqlen = gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
#position_ids = position_ids.unsqueeze(0)

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
@@ -500,18 +574,19 @@ def forward(
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
if use_packed_dataset:
attn_output = hf_q_k_v_with_cu_seqlens(
query_states,
key_states,
value_states,
cumulative_len=cu_seqlens,
max_seqlen=max_seqlen,
dropout_p=dropout_rate,
).unsqueeze(0)
else:
attn_output = hf_q_k_v_without_cu_seqlens(
query_states, key_states, value_states, dropout_p=dropout_rate, softmax_scale=None, causal=True,
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
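For packed batches, cu_seqlens and max_seqlen (read from gpc.config.data above) describe where each original sample starts and ends inside the single bsz=1 sequence, so the varlen attention call never attends across sample boundaries. A minimal sketch of how those values are formed for a toy packed batch; the internals of hf_q_k_v_with_cu_seqlens are not shown and its exact tensor layout is assumed to follow the flash-attention varlen convention:

import torch

# Three samples of lengths 3, 5 and 2 packed into one bsz=1 sequence of 10 tokens.
lengths = [3, 5, 2]
cu_seqlens = torch.cumsum(torch.tensor([0] + lengths), dim=0).to(torch.int32)  # tensor([0, 3, 8, 10])
max_seqlen = max(lengths)  # 5
# hf_q_k_v_with_cu_seqlens forwards these so attention is computed independently
# per sample slice [cu_seqlens[i]:cu_seqlens[i + 1]] of the packed sequence.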
Expand Down Expand Up @@ -747,7 +822,7 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
elif isinstance(module, Embedding1D):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@@ -844,7 +919,7 @@ def __init__(self, config: LlamaConfig):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.embed_tokens = Embedding1D(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
@@ -882,7 +957,6 @@ def forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict


if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
@@ -893,7 +967,7 @@
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

@@ -1030,27 +1104,18 @@ def _update_causal_mask(
else past_seen_tokens + sequence_length + 1
)

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)

if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
@@ -1072,7 +1137,7 @@ def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.lm_head = new_linear("head", config.hidden_size, config.vocab_size, bias=False)

# Initialize weights and apply final processing
self.post_init()
Expand Down Expand Up @@ -1217,11 +1282,36 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
model_inputs = {"input_ids": input_ids}

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device

dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min

attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)

model_inputs.update(
{