diff --git a/README.md b/README.md index 4ca840392e..62b1d69e01 100644 --- a/README.md +++ b/README.md @@ -89,23 +89,38 @@ This tool helps convert the format of paramters between Megatron-LLaMA/Megatron- **HuggingFace to Megatron-LLaMA** +For LLaMA: ``` -sh tools/checkpoint_conversion/hf_to_megatron.sh +sh tools/checkpoint_conversion/hf_to_megatron.sh +``` +For Baichuan: +``` +sh tools/checkpoint_conversion/baichuan_hf_to_megatron.sh ``` **Megatron-LLaMA to HuggingFace** +For LLaMA: ``` sh tools/checkpoint_conversion/megatron_to_hf.sh ``` +For Baichuan: +``` +sh tools/checkpoint_conversion/baichuan_megatron_to_hf.sh +``` #### B. Launching scripts **Single-node launching** +For LLaMA: ``` sh examples/LLaMA/LLaMA_13_standalone.sh ``` +For Baichuan: +``` +sh examples/Baichuan_13_standalone.sh +``` **Distributed launching** diff --git a/README_zh.md b/README_zh.md index be02c94214..3a41a7d9e9 100644 --- a/README_zh.md +++ b/README_zh.md @@ -85,17 +85,27 @@ Megatron-LLaMA使用方式与Megatron-LM基本一致,详细信息请参考[Meg **HuggingFace to Megatron-LLaMA** +LLaMA: ``` sh tools/checkpoint_conversion/hf_to_megatron.sh ``` +Baichuan: +``` +sh tools/checkpoint_conversion/baichuan_hf_to_megatron.sh +``` 完成训练后,将训练产出的权重转换成HuggingFace支持的格式,方便后续使用: **Megatron-LLaMA to HuggingFace** +LLaMA: ``` sh tools/checkpoint_conversion/megatron_to_hf.sh ``` +Baichuan: +``` +sh tools/checkpoint_conversion/baichuan_megatron_to_hf.sh +``` ### B. LLaMA训练脚本 diff --git a/examples/Baichuan_13_standalone.sh b/examples/Baichuan_13_standalone.sh new file mode 100644 index 0000000000..b2e165fb9d --- /dev/null +++ b/examples/Baichuan_13_standalone.sh @@ -0,0 +1,91 @@ +DATASET_1="" +DATASET_2="" +DATASET_3="" +DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}" + +TP_SIZE=2 +PP_SIZE=1 +WORLD_SIZE=8 +MICRO_BATCH_SIZE=2 +# The int is the number of micro steps of gradient accumulation +GLOBAL_BATCH_SIZE=$((($WORLD_SIZE * $MICRO_BATCH_SIZE) / ($TP_SIZE * $PP_SIZE) * 8)) +# GLOBAL_BATCH_SIZE=128 + +JOB_NAME="Baichuan_tp${TP_SIZE}_pp${PP_SIZE}_mbs${MICRO_BATCH_SIZE}_gpus${WORLD_SIZE}" + +LOAD_CHECKPOINT_PATH="PATH TO THE MODEL CHECKPOINT" +SAVE_CHECKPOINT_PATH="PATH TO SAVE MODEL CHECKPOINT" +TOKENIZER_PATH="PATH OR NAME FOR PRETRAINED TOKENIZER" +TENSORBOARD_DIR="TENSORBOARD DIRECTORY" + +TRAIN_ITERS=1000 +EVAL_ITERS=10 +EVAL_INTERVAL=1000 +SAVE_INTERVAL=100 +LOG_INTERVAL=1 + +# Setting --tensorboard-queue-size to 1 significantly slows down the training +options=" \ + --finetune \ + --sequence-parallel \ + --tensor-model-parallel-size ${TP_SIZE} \ + --pipeline-model-parallel-size ${PP_SIZE} \ + --num-layers 40 \ + --hidden-size 5120 \ + --num-attention-heads 40 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --no-position-embedding \ + --position-embedding-type alibi \ + --swiglu \ + --ffn-hidden-size 13696 \ + --disable-bias-linear \ + --RMSNorm \ + --attention-dropout 0 \ + --hidden-dropout 0 \ + --layernorm-epsilon 1e-6 \ + --causal-lm \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path $TOKENIZER_PATH \ + --make-vocab-size-divisible-by 1 \ + --trust-remote-code \ + --init-method-std 0.01 \ + --micro-batch-size ${MICRO_BATCH_SIZE} \ + --global-batch-size ${GLOBAL_BATCH_SIZE} \ + --train-iters ${TRAIN_ITERS} \ + --lr 6.0e-5 \ + --lr-decay-iters 10 \ + --lr-warmup-iters 5 \ + --min-lr 6.0e-6 \ + --override-opt_param-scheduler \ + --lr-decay-style cosine \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --clip-grad 1.0 \ + --weight-decay 0.1 \ + --overlapped-distributed-optimizer \ + --reduce-bucket-size=2e8 \ + 
--no-gradient-accumulation-fusion \ + --dataloader-type cyclic \ + --data-impl mmap \ + --data-path ${DATASET} \ + --split 98,2,0 \ + --eval-interval ${EVAL_INTERVAL} \ + --eval-iters ${EVAL_ITERS} \ + --save-interval ${SAVE_INTERVAL} \ + --save ${SAVE_CHECKPOINT_PATH} \ + --load ${LOAD_CHECKPOINT_PATH} \ + --no-load-optim \ + --log-interval ${LOG_INTERVAL} \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --tensorboard-queue-size 1000 \ + --log-timers-to-tensorboard \ + --log-batch-size-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + --job-name ${JOB_NAME} \ + --bf16 \ + --recompute-activations \ + --recompute-granularity selective \ + --use-flash-attn" + +torchrun --nproc_per_node=8 --master_port=29500 pretrain_baichuan.py ${options} diff --git a/megatron/arguments.py b/megatron/arguments.py index 513366c7d3..c2c52cc0b0 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -529,6 +529,9 @@ def _add_network_size_args(parser): group.add_argument('--max-position-embeddings', type=int, default=None, help='Maximum number of position embeddings to use. ' 'This is the size of position embedding.') + group.add_argument('--position-embedding-type', type=str, default='learned_absolute', + choices=['learned_absolute', 'rope', 'alibi'], + help='Position embedding type.') group.add_argument('--use-rotary-position-embeddings', action='store_true', help='Use rotary positional embeddings or not') group.add_argument('--rotary-percent', type=float, default=1.0, @@ -1104,6 +1107,8 @@ def _add_data_args(parser): help='Sentencepiece tokenizer model.') group.add_argument('--tokenizer-name-or-path', type=str, default=None, help='tokenizer model path for PretrainedFromHF.') + group.add_argument('--trust-remote-code', action='store_true', + help='Whether trust remote code when using PretrainedFromHF.') group.add_argument('--data-impl', type=str, default='infer', choices=['lazy', 'cached', 'mmap', 'infer'], help='Implementation of indexed datasets.') diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index 168b65cae6..bfcdaff78a 100644 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -6,6 +6,7 @@ from .bert_model import BertModel from .gpt_model import GPTModel from .llama_model import LLaMAModel +from .baichuan_model import BaichuanModel from .t5_model import T5Model from .language_model import get_language_model from .module import Float16Module diff --git a/megatron/model/baichuan_model.py b/megatron/model/baichuan_model.py new file mode 100644 index 0000000000..0ffc016b2f --- /dev/null +++ b/megatron/model/baichuan_model.py @@ -0,0 +1,169 @@ +# Copyright (c) 2023, ALIBABA CORPORATION. All rights reserved. + + +"""Baichuan model.""" + +import torch + +from megatron import get_args +from megatron.core import tensor_parallel +from .module import MegatronModule + +from .enums import AttnMaskType +from .language_model import parallel_lm_logits +from .language_model import get_language_model +from .utils import init_method_normal +from .utils import scaled_init_method_normal + + +def post_language_model_processing(lm_output, labels, logit_weights, + parallel_output, + fp16_lm_cross_entropy): + + # Output. 
Format [s b h] + output = parallel_lm_logits( + lm_output, + logit_weights, + parallel_output) + + if labels is None: + # [s b h] => [b s h] + return output.transpose(0,1).contiguous() + else: + # [b s] => [s b] + labels = labels.transpose(0,1).contiguous() + if fp16_lm_cross_entropy: + assert output.dtype == torch.half + loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) + else: + loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) + + # [s b] => [b, s] + loss = loss.transpose(0,1).contiguous() + return loss + + +class BaichuanModel(MegatronModule): + """Baichuan Language model.""" + + def __init__(self, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True): + args = get_args() + super(BaichuanModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights) + + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights + self.sequence_parallel = args.sequence_parallel + self.padded_vocab_size = args.padded_vocab_size + + self.language_model, self._language_model_key = get_language_model( + num_tokentypes=num_tokentypes, + add_pooler=False, + encoder_attn_mask_type=AttnMaskType.causal, + init_method=init_method_normal(args.init_method_std), + scaled_init_method=scaled_init_method_normal(args.init_method_std, + args.num_layers), + pre_process=self.pre_process, + post_process=self.post_process) + + self.causal_lm = args.causal_lm + + if not args.untie_embeddings_and_output_weights and not self.causal_lm: + self.initialize_word_embeddings(init_method_normal) + + if self.causal_lm and self.post_process: + self.lm_head = torch.nn.Linear(args.hidden_size, args.padded_vocab_size, bias=False) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def _causal_lm_process(self, lm_output, labels): + if self.sequence_parallel: + lm_output = tensor_parallel.gather_from_sequence_parallel_region(lm_output, False) + lm_output = lm_output.transpose(0, 1) + logits = self.lm_head(lm_output) + + if labels is None: + return logits + else: + loss = None + # [invalid] Shift so that tokens < n predict n + # Do not need to shift here + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., :-1].contiguous() + # Flatten the tokens + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0) + shift_logits = shift_logits.view(-1, self.padded_vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return loss + + def forward(self, input_ids, position_ids, attention_mask, + ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None, + labels=None, tokentype_ids=None, inference_params=None): + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + ret_input_ids=ret_input_ids, + ret_position_ids=ret_position_ids, + ret_attn_mask=ret_attn_mask, + inference_params=inference_params) + + if self.post_process: + if self.causal_lm: + return self._causal_lm_process(lm_output=lm_output, labels=labels) + else: + return post_language_model_processing( + lm_output, labels, + self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else 
self.word_embeddings_weight(), + self.parallel_output, + self.fp16_lm_cross_entropy) + else: + return lm_output + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + prefix=prefix, keep_vars=keep_vars) + # Save word_embeddings. + if (self.post_process + and not self.pre_process + and not self.untie_embeddings_and_output_weights + and not self.causal_lm): + state_dict_[self._word_embeddings_for_head_key] \ + = self.word_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + if self.post_process and self.causal_lm: + state_dict_['lm_head'] = self.lm_head.state_dict() + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + if self.causal_lm and self.post_process: + self.lm_head.load_state_dict(state_dict['lm_head'], strict=strict) + + # Load word_embeddings. + if self.post_process and \ + not self.pre_process \ + and not self.untie_embeddings_and_output_weights \ + and not self.causal_lm: + self.word_embeddings.load_state_dict( + state_dict[self._word_embeddings_for_head_key], strict=strict) + if self._language_model_key in state_dict: + state_dict = state_dict[self._language_model_key] + self.language_model.load_state_dict(state_dict, strict=strict) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index cd6a9dd444..425ea7e2cc 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -16,7 +16,7 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb -from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu +from megatron.model.utils import attention_mask_func, alibi_mask_func, get_slopes, _buffered_future_mask, _gen_alibi_mask, openai_gelu, erf_gelu try: from einops import rearrange @@ -26,7 +26,15 @@ try: from flash_attn.flash_attn_interface import flash_attn_unpadded_func except ImportError: - flash_attn_unpadded_func = None + try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func + except ImportError: + flash_attn_unpadded_func = None + +try: + from flash_attn.flash_attn_triton import flash_attn_func +except ImportError: + flash_attn_func = None """ We use the following notation throughout this file: h: hidden size @@ -234,11 +242,16 @@ def __init__(self, layer_number, coeff = self.layer_number self.norm_factor *= coeff + cur_mask_func = attention_mask_func + self.position_embedding_type = args.position_embedding_type + if args.position_embedding_type == "alibi": + cur_mask_func = alibi_mask_func + self.scale_mask_softmax = FusedScaleMaskSoftmax( self.fp16, self.bf16, self.attn_mask_type, args.masked_softmax_fusion, - attention_mask_func, + cur_mask_func, self.attention_softmax_in_fp32, coeff) @@ -253,7 +266,8 @@ def forward(self, query_layer, key_layer, # =================================== # Raw attention scores. 
[b, np, s, s] # =================================== - + + q_len = query_layer.size(0) # [b, np, sq, sk] output_size = (query_layer.size(1), query_layer.size(2), @@ -287,8 +301,12 @@ def forward(self, query_layer, key_layer, # =========================== # attention scores and attention mask [b, np, sq, sk] - attention_probs = self.scale_mask_softmax(attention_scores, - attention_mask) + if q_len == 1 and self.position_embedding_type == "alibi": # inference with cache + if len(attention_mask.size()) == 4: + attention_mask = attention_mask[:, :, -1:, :] + else: + attention_mask = attention_mask[:, -1:, :] + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. @@ -356,7 +374,7 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, self.softmax_scale = softmax_scale self.dropout_p = attention_dropout - def forward(self, q, k, v): + def forward(self, q, k, v, bias=None): """Implements the multihead softmax attention. Arguments --------- @@ -368,8 +386,9 @@ def forward(self, q, k, v): batch_size, seqlen_q = q.shape[0], q.shape[1] seqlen_k = k.shape[1] + if bias is None: + q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]] - q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]] cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device) @@ -387,13 +406,17 @@ def forward(self, q, k, v): device=q.device) self.dropout_p = 0 - output = flash_attn_unpadded_func( - q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, - self.dropout_p, - softmax_scale=self.softmax_scale, causal=is_causal - ) - - output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + if bias is None: + output = flash_attn_unpadded_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, + self.dropout_p, + softmax_scale=self.softmax_scale, causal=is_causal + ) + else: + output = flash_attn_func( + q, k, v, bias, is_causal, self.softmax_scale + ) + # output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) return output @@ -417,6 +440,7 @@ def __init__(self, init_method, self.sequence_parallel = args.sequence_parallel self.use_flash_attn = args.use_flash_attn + self.position_embedding_type = args.position_embedding_type if self.use_flash_attn: if flash_attn_unpadded_func is None: raise ImportError('FlashAttention is not installed, please install with ' @@ -471,6 +495,15 @@ def __init__(self, init_method, self.core_attention = CoreAttention(self.layer_number, self.attn_mask_type) self.checkpoint_core_attention = args.recompute_granularity == 'selective' + + self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling + tensor_parallel_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = core.utils.divide(projection_size, + tensor_parallel_size) + self.hidden_size_per_attention_head = core.utils.divide( + projection_size, args.num_attention_heads) + self.num_attention_heads_per_partition = core.utils.divide( + args.num_attention_heads, tensor_parallel_size) if self.use_flash_attn: self.core_attention_flash = FlashSelfAttention( @@ -656,11 +689,18 @@ def forward(self, hidden_states, attention_mask, else: q, k, v = [rearrange(x, 's b ... 
-> b s ...').contiguous() for x in (query_layer, key_layer, value_layer)] + cur_mask = None if not self.sequence_parallel: with tensor_parallel.get_cuda_rng_tracker().fork(): - context_layer = self.core_attention_flash(q, k, v) + if self.position_embedding_type == "alibi": + context_layer = self.core_attention_flash(q, k, v, bias=attention_mask) + else: + context_layer = self.core_attention_flash(q, k, v) else: - context_layer = self.core_attention_flash(q, k, v) + if self.position_embedding_type == "alibi": + context_layer = self.core_attention_flash(q, k, v, bias=attention_mask) + else: + context_layer = self.core_attention_flash(q, k, v) context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() # ================= @@ -995,6 +1035,9 @@ def __init__(self, init_method, output_layer_init_method, self.input_tensor = None self.drop_path_rate = drop_path_rate self.transformer_impl = args.transformer_impl + self.position_embedding_type = args.position_embedding_type + self.num_attention_heads = args.num_attention_heads + self.params_dtype = args.params_dtype # Store activation checkpoiting flag. self.recompute_granularity = args.recompute_granularity @@ -1002,6 +1045,10 @@ def __init__(self, init_method, output_layer_init_method, self.recompute_num_layers = args.recompute_num_layers self.distribute_saved_activations = \ args.distribute_saved_activations and not args.sequence_parallel + + self.first_run = True + self.max_cache_pos = args.max_position_embeddings + self.alibi_mask = None self.sequence_parallel = args.sequence_parallel @@ -1222,11 +1269,107 @@ def set_input_tensor(self, input_tensor): forward_step_func""" self.input_tensor = input_tensor + def _build_alibi_tensor(self, tensor, max_seq_len, num_attention_heads): + # Copied from bigscience-workshop/Megatron-DeepSpeed + # Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + """Returns tensor shaped + (1, num_attention_heads_per_partition, 1, max_seq_len), + """ + if self.training: + slopes = torch.tensor(get_slopes(num_attention_heads), device=torch.cuda.current_device()) + position_point = ( + torch.arange(max_seq_len) - max_seq_len + 1 + ).to(torch.cuda.current_device()) + position_point = ( + position_point.unsqueeze(0) + .unsqueeze(0) + .expand(num_attention_heads, max_seq_len, -1) + ) + diag = torch.diag(position_point[0]) + position_point = position_point - diag.unsqueeze(0).unsqueeze(0).transpose( + -1, -2 + ) + alibi = slopes.unsqueeze(1).unsqueeze(1) * position_point + mask = _buffered_future_mask( + tensor, max_seq_len, alibi, num_attention_heads + ) + else: + if self.first_run: + self.first_run = False + self.register_buffer( + "future_mask", + _gen_alibi_mask(num_attention_heads, self.max_cache_pos).to( + tensor + ), + persistent=False, + ) + if max_seq_len > self.max_cache_pos: + self.max_cache_pos = max_seq_len + self.register_buffer( + "future_mask", + _gen_alibi_mask(num_attention_heads, self.max_cache_pos).to( + tensor + ), + persistent=False, + ) + mask = self.future_mask[ + : num_attention_heads, :max_seq_len, :max_seq_len + ] + return mask + + def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, inference_params=None, rotary_pos_emb=None): # hidden_states: [s, b, h] + if self.position_embedding_type == "alibi": + # assert not args.use_flash_attn, \ + # 'ALiBi does not work with FlashAttention currently' + seq_len = hidden_states.shape[0] + if 
self.sequence_parallel:
+                seq_len = seq_len * mpu.get_tensor_model_parallel_world_size()
+            if self.training:
+                if (
+                    self.alibi_mask is None
+                    or self.alibi_mask.shape[-1] != seq_len
+                ):
+                    self.alibi_mask = self._build_alibi_tensor(
+                        hidden_states, seq_len, self.num_attention_heads
+                    ).to(torch.cuda.current_device())
+                alibi_mask = self.alibi_mask
+            else:
+                alibi_mask = self._build_alibi_tensor(hidden_states, seq_len, self.num_attention_heads).to(torch.cuda.current_device())
+            if self.params_dtype is torch.float16:
+                alibi_mask = alibi_mask.to(torch.float16)
+            elif self.params_dtype is torch.bfloat16:
+                alibi_mask = alibi_mask.to(torch.bfloat16)  # [head, seq_len, seq_len]
+            # Select the part of the tensor that corresponds to our tensor
+            # parallel index.
+            tp_world_size = mpu.get_tensor_model_parallel_world_size()
+            tp_index = mpu.get_tensor_model_parallel_rank()
+            alibi_mask = alibi_mask.reshape((tp_world_size, -1, *alibi_mask.shape[1:]))[tp_index]  # [num_attention_heads/world, seq_len, max_seq_len]
+            if attention_mask is not None:
+                if len(attention_mask.shape) == 2:
+                    expanded_mask = attention_mask.to(alibi_mask.dtype)
+                    expanded_mask = torch.tril(
+                        torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
+                    ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
+                else:
+                    expanded_mask = attention_mask
+                bsz = hidden_states.size(1)
+                src_len, tgt_len = alibi_mask.size()[-2:]
+                expanded_mask = (
+                    expanded_mask
+                    .expand(bsz, 1, src_len, tgt_len)
+                    .to(alibi_mask.dtype)
+                )
+                expanded_mask = expanded_mask.masked_fill(
+                    expanded_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min
+                )
+                attention_mask = expanded_mask + alibi_mask.unsqueeze(0)  # [batch_size, head_size, seq_len, seq_len]
+            else:
+                attention_mask = alibi_mask

        # Checks.
if inference_params: assert self.recompute_granularity is None, \ diff --git a/megatron/model/utils.py b/megatron/model/utils.py index cf3727c02b..0464e298b6 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -31,6 +31,54 @@ def attention_mask_func(attention_scores, attention_mask): return attention_scores +def alibi_mask_func(attention_scores, attention_mask): + attention_scores = attention_scores + attention_mask + return attention_scores + + +def get_slopes(n): + def get_slopes_power_of_2(n): + start = (2 ** (-2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio ** i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes( + 2 * closest_power_of_2, + )[0::2][:n - closest_power_of_2] + ) + + +def _fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) + + +def _buffered_future_mask(tensor, maxpos, alibi, attn_heads): + _future_mask = torch.triu(_fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1).to(torch.cuda.current_device()) + _future_mask = _future_mask.unsqueeze(0) + alibi + new_future_mask = _future_mask.to(tensor) + return new_future_mask[: tensor.shape[0] * attn_heads, :maxpos, :maxpos] + + +def _gen_alibi_mask(num_attention_heads, max_seq_len): + slopes = torch.tensor(get_slopes(num_attention_heads), device=torch.cuda.current_device()) + alibi = ( + slopes.unsqueeze(1).unsqueeze(1) + * torch.arange(max_seq_len, device=torch.cuda.current_device()).unsqueeze(0).unsqueeze(0).expand( + num_attention_heads, -1, -1) + ) + # alibi = alibi.unsqueeze(0) + alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_seq_len, max_seq_len])), 1).to(torch.cuda.current_device()) # [max_seq_len, max_seq_len] + alibi_mask = alibi_mask.unsqueeze(0) + alibi # [num_attention_heads, max_seq_len, max_seq_len] + return alibi_mask + + def get_linear_layer(rows, columns, init_method): """Simple linear layer with weight initialization.""" layer = torch.nn.Linear(rows, columns) diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index d87a7c98d9..9f259c239c 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -43,7 +43,7 @@ def build_tokenizer(args): tokenizer = _NullTokenizer(args.vocab_size) elif args.tokenizer_type == 'PretrainedFromHF': assert args.tokenizer_name_or_path is not None - tokenizer = _AutoTokenizer(args.tokenizer_name_or_path, vocab_extra_ids=args.vocab_extra_ids) + tokenizer = _AutoTokenizer(args.tokenizer_name_or_path, trust_remote_code=args.trust_remote_code, vocab_extra_ids=args.vocab_extra_ids) else: raise NotImplementedError('{} tokenizer is not ' 'implemented.'.format(args.tokenizer_type)) @@ -545,13 +545,13 @@ def additional_special_tokens_ids(self): class _AutoTokenizer(AbstractTokenizer): """AutoTokenizer for Hf Pretrained model loading.""" - def __init__(self, tokenizer_name_or_path, vocab_extra_ids=0): + def __init__(self, tokenizer_name_or_path, trust_remote_code=False, vocab_extra_ids=0): name = tokenizer_name_or_path super().__init__(name) hf_tokenizer_kwargs = {} if vocab_extra_ids > 0: hf_tokenizer_kwargs["additional_special_tokens"] = [f"" for _id in range(vocab_extra_ids)] - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **hf_tokenizer_kwargs) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, 
trust_remote_code=trust_remote_code,**hf_tokenizer_kwargs) self.tokenizer.pad_token = self.tokenizer.eos_token self.encoder = self.tokenizer.get_vocab() self.decoder = {v: k for k, v in self.encoder.items()} diff --git a/pretrain_baichuan.py b/pretrain_baichuan.py new file mode 100644 index 0000000000..f5fc978a2f --- /dev/null +++ b/pretrain_baichuan.py @@ -0,0 +1,122 @@ +# Copyright (c) 2023, ALIBABA CORPORATION. All rights reserved. + +"""Pretrain Baichuan""" +import os + +import torch +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron.core import tensor_parallel +from megatron.core.enums import ModelType +from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.model import BaichuanModel +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.utils import average_losses_across_data_parallel_group + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building Baichuan model ...') + model = BaichuanModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + return model + + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss) + + return tokens, labels, loss_mask, attention_mask, position_ids + +def loss_func(loss_mask, output_tensor): + loss = output_tensor.float() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'lm loss': averaged_loss[0]} + + +def forward_step(data_iterator, model): + """Forward step.""" + timers = get_timers() + + # Get the batch. 
+ timers('batch-generator', log_level=2).start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + + print_rank_0('> building train, validation, and test datasets ' + 'for Baichuan ...') + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path) + print_rank_0("> finished creating Baichuan datasets ...") + + return train_ds, valid_ds, test_ds + + +if __name__ == "__main__": + os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' + # os.environ['NCCL_NET_GDR_READ'] = '0' + # os.environ['NCCL_NET_GDR_LEVEL'] = '0' + os.environ['NCCL_MIN_NCHANNELS'] = '16' + + pretrain(train_valid_test_datasets_provider, model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'} + ) diff --git a/tools/checkpoint_conversion/baichuan_checkpoint_conversion.py b/tools/checkpoint_conversion/baichuan_checkpoint_conversion.py new file mode 100644 index 0000000000..292b241000 --- /dev/null +++ b/tools/checkpoint_conversion/baichuan_checkpoint_conversion.py @@ -0,0 +1,876 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import shutil +import json +import os +import re +import sys +import types + +import torch + +from transformers import AutoConfig +from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, shard_checkpoint + + +def add_checkpointing_args(parser): + parser.add_argument("--megatron-path", type=str, default=None, help="Base directory of Megatron repository") + parser.add_argument( + "--convert_checkpoint_from_megatron_to_transformers", + action="store_true", + help=( + "If True, convert a Megatron checkpoint to a Transformers checkpoint. " + "If False, convert a Transformers checkpoint to a Megatron checkpoint." 
+ ), + ) + parser.add_argument( + "--load_path", + type=str, + required=True, + help="Path to the checkpoint to convert.", + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + help="Path to the converted checkpoint.", + ) + parser.add_argument( + "--config_path", + type=str, + help="Path to the converted checkpoint.", + default="", + ) + parser.add_argument("--print-checkpoint-structure", action="store_true") + + return parser + + +def add_megatron_checkpoint_args(parser): + parser.add_argument( + "--target_tensor_model_parallel_size", + type=int, + default=1, + help=( + "The tensor model parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_pipeline_model_parallel_size", + type=int, + default=1, + help=( + "The pipeline model parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_data_parallel_size", + type=int, + default=1, + help=( + "The data parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_params_dtype", + type=str, + default="fp32", + help=( + "The dtype of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--make_vocab_size_divisible_by", + type=int, + default=128, + help=( + "Pad the vocab size to be divisible by this value. " + "This is added for computational efficieny reasons. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--use_distributed_optimizer", + action="store_true", + help=( + "If True, use the distributed optimizer. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + return parser + + +def add_transformers_checkpoint_args(parser): + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help=( + "The name of the pre-trained tokenizer to save. " + "If not None, the tokenizer will be saved. " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + parser.add_argument( + "--max_shard_size", + type=str, + default="60GB", + help=( + "The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size " + "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + + return parser + + +# The simple map of names for "automated" rules. 
+megatron_to_transformers = { + "self_attention.dense": ".self_attn.o_proj.", + "mlp.dense_4h_to_h": ".mlp.down_proj.", +} + +tensor_parallel_params = [ + # megatron-lm layers to merge across tp ranks + "self_attention.query_key_value.weight", + "self_attention.query_key_value.bias", + "self_attention.dense.weight", + "mlp.dense_h_to_4h.weight", + "mlp.dense_h_to_4h.bias", + "mlp.dense_4h_to_h.weight", + # deprecated + "attention.query_key_value.weight", + "attention.query_key_value.bias", + "attention.dense.weight", + # transformers layers to split across tp ranks + "attn.c_attn.weight", + "attn.c_attn.bias", + "attn.c_proj.weight", + "mlp.c_fc.weight", + "mlp.c_fc.bias", + "mlp.c_proj.weight", + 'self_attn.q_proj.weight', + 'self_attn.k_proj.weight', + 'self_attn.v_proj.weight', + 'self_attn.o_proj.weight', + 'mlp.down_proj.weight', + 'mlp.up_proj.weight', + 'mlp.gate_proj.weight', + 'self_attn.W_pack.weight' +] + + +def recursive_print(name, val, spaces=0): + """ + Recursively print the structure of a checkpoint. This function is taken from `convert_megatron_gpt2_checkpoint.py` + + Args: + name (str): the name of the current tensor parameter + val (Tuple(int)): the shape of the current tensor parameter + spaces (int): the number of spaces to print before the output for a nested structure + """ + # Format the message. + if name is None: + msg = None + else: + fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}" + msg = fmt.format(name) + + # Print and recurse (if needed). + if isinstance(val, dict): + if msg is not None: + print(msg) + for k in val.keys(): + recursive_print(k, val[k], spaces + 2) + elif isinstance(val, torch.Tensor): + print(msg, ":", val.size()) + else: + print(msg, ":", val) + + +def merge_transformers_sharded_states(path, num_checkpoints): + """ + Merge sharded checkpoints from transformers into a single checkpoint. + + Args: + path (str): the path to the sharded checkpoints + num_checkpoints (int): the number of checkpoints to merge + """ + state_dict = {} + for i in range(1, num_checkpoints + 1): + print('loading', i, ':', num_checkpoints + 1) + checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin") + if not os.path.exists(checkpoint_path): + checkpoint_path = os.path.join(path, f"pytorch_model-{i}-of-{num_checkpoints}.bin") + assert os.path.exists(checkpoint_path), f"Cannot find checkpoint {checkpoint_path}" + current_chunk = torch.load(checkpoint_path, map_location="cpu") + state_dict.update(current_chunk) + return state_dict + + +def get_megatron_sharded_states(load_path, tp_size, pp_size, pp_rank): + """ + Get sharded checkpoints from NVIDIA Megatron-LM checkpoint based on the provided tensor parallel size, pipeline + parallel size and pipeline parallel rank. 
+ + Args: + args (argparse.Namespace): the arguments to the script + tp_size (int): the tensor parallel size + pp_size (int): the pipeline parallel size + pp_rank (int): the pipeline parallel rank + """ + tp_state_dicts = [] + for i in range(tp_size): + possible_sub_dir_names = [ + f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}", + f"mp_rank_{i:02d}_dp_000" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}_dp_000" + ] + sub_dir_name = None + for p in possible_sub_dir_names: + if os.path.exists(os.path.join(load_path, p)): + sub_dir_name = p + break + assert sub_dir_name is not None, f"Cannot find sub dir in {possible_sub_dir_names}" + checkpoint_path = os.path.join(load_path, sub_dir_name, 'model_optim_rng.pt') + state_dict = torch.load(checkpoint_path, map_location="cpu") + tp_state_dicts.append(state_dict) + return tp_state_dicts + + +def get_element_from_dict_by_path(d, path): + """ + Get element from dictionary by path. If element is not present, recursively add empty dictionaries. + + Args: + d (dict): the dictionary to get the element from + path (list): the path to the element which is delimited by "." + """ + path = path.split(".") + for k in path: + if k not in d: + d[k] = {} + d = d[k] + return d + + +def copy_tokenizer(args): + os.makedirs(args.save_path, exist_ok=True) + tokenizer_dir = args.load_path + if os.path.exists(os.path.join(args.load_path, 'tokenizer')): + tokenizer_dir = os.path.join(args.load_path, 'tokenizer') + file_list = os.listdir(tokenizer_dir) + for f in file_list: + if 'token' in f: + shutil.copyfile(os.path.join(tokenizer_dir, f), os.path.join(args.save_path, f)) + + +def convert_checkpoint_from_megatron_to_transformers(args): + """ + Convert NVIDIA Megatron-LM checkpoint to HuggingFace Transformers checkpoint. This handles Megatron checkpoints + with different tensor parallelism and pipeline parallelism sizes. It saves the converted checkpoint into shards + using HuggingFace Transformers checkpoint sharding functionality. 
+ + Args: + args (argparse.Namespace): the arguments to the script + """ + # Search in directory above this + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + if args.megatron_path is not None: + sys.path.insert(0, args.megatron_path) + + # Load Megatron-LM checkpoint arguments from the state dict + sub_dirs = os.listdir(args.load_path) + release = False + if 'latest_checkpointed_iteration.txt' in sub_dirs: + with open(os.path.join(args.load_path, 'latest_checkpointed_iteration.txt')) as f: + latest_ckpt = f.readline().strip() + print(f"latest checkpoint: {latest_ckpt}") + if isinstance(latest_ckpt, bytearray): + latest_ckpt = latest_ckpt.decode("utf-8") + try: + iteration = int(latest_ckpt) + except ValueError: + release = (latest_ckpt == "release") + if not release: + raise ValueError(f"Invalid latest checkpoint: {latest_ckpt}") + for sub_dir in sub_dirs: + if latest_ckpt in sub_dir: + latest_ckpt = sub_dir + break + else: + raise ValueError('Cannot find latest ckpt!') + possible_state_paths = [os.path.join(args.load_path, latest_ckpt), + os.path.join(args.load_path, latest_ckpt, + 'iter_' + str(iteration) if not release else 'release')] + state_path = None + for p in possible_state_paths: + if os.path.exists(p): + state_path = p + print(f"Loading Megatron-LM checkpoint arguments from: {state_path}") + break + assert state_path is not None, f"Cannot find state path in {possible_state_paths}" + possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000", "mp_rank_00_dp_000", "mp_rank_00_000_dp_000"] + state_dirs = os.listdir(state_path) + for sub_dir in possible_sub_dirs: + if sub_dir in state_dirs: + rank0_checkpoint_path = os.path.join(state_path, sub_dir, 'model_optim_rng.pt') + break + print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}") + state_dict = torch.load(rank0_checkpoint_path, map_location="cpu") + megatron_args = state_dict.get("args", None) + if megatron_args is None: + raise ValueError( + "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints" + " containing all the megatron arguments. This is because it loads all config related to model" + " architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to" + " manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron" + " arguments to use this utility." 
+ ) + + # Create Transformers GPT2 config from Megatron-LM arguments + vocab_size = megatron_args.padded_vocab_size + + # params dtype + if args.target_params_dtype == "fp16": + dtype = torch.float16 + elif args.target_params_dtype == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + # config = AutoConfig.from_pretrained(args.config_path) + config = AutoConfig.from_pretrained(args.config_path, trust_remote_code=True) + config.bos_token_id = 1 + config.eos_token_id = 2 + config.hidden_act = 'silu' + config.hidden_size = megatron_args.hidden_size + config.intermediate_size = megatron_args.ffn_hidden_size + config.initializer_range = 0.02 + config.model_max_length = megatron_args.max_position_embeddings + config.model_type = 'baichuan' + config.num_attention_heads = megatron_args.num_attention_heads + config.num_hidden_layers = megatron_args.num_layers + config.pad_token_id = 0 + config.rms_norm_eps = 1e-6 + config.torch_dtype = 'bfloat16' + config.transformers_version = '4.29.2' + config.use_cache = True + config.vocab_size = vocab_size + config.architectures = ["BaichuanForCausalLM"] + + output_state_dict = {} + + tp_size = megatron_args.tensor_model_parallel_size + pp_size = megatron_args.pipeline_model_parallel_size + + # The regex to extract layer names. + layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + + # Convert. + print("Converting") + + # Embeddings + print("Converting embeddings") + tp_state_dicts = get_megatron_sharded_states(state_path, tp_size, pp_size, 0) + + # Convert and store the word embeddings. + word_embeddings = torch.cat( + [ + get_element_from_dict_by_path( + tp_state_dicts[tp_rank], "model.language_model.embedding.word_embeddings.weight" + ) + for tp_rank in range(tp_size) + ], + dim=0, + ) + word_embeddings = word_embeddings[:vocab_size].to(dtype).clone().detach().contiguous() + output_state_dict["model.embed_tokens.weight"] = word_embeddings + + # Transformer Layers + print("Converting transformer layers") + # The hidden_size per head. + hidden_size_per_head = config.hidden_size // config.num_attention_heads + num_layers = config.num_hidden_layers // pp_size + layer_idx = 0 + for pp_rank in range(pp_size): + if pp_size > 0: + print(f"Converting pipeline parallel rank {pp_rank}") + tp_state_dicts = get_megatron_sharded_states(state_path, tp_size, pp_size, pp_rank) + # The transformer. + path = "model.language_model.encoder" + + # Extract the layers. + for key, val in get_element_from_dict_by_path(tp_state_dicts[0], path).items(): + # Match the name. + m = layer_re.match(key) + # Stop if that's not a layer + if m is None: + break + + # The index of the layer. + layer_idx = int(m.group(1)) + pp_rank * num_layers + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + # The name of the layer. + layer_name = f"model.layers.{layer_idx}" + + if op_name + "." + weight_or_bias not in tensor_parallel_params: + params = val.to(dtype) + else: + dim = 1 if op_name in ["self_attention.dense", "mlp.dense_4h_to_h"] else 0 + params = torch.cat( + [val] + + [ + get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key] + for tp_rank in range(1, tp_size) + ], + dim=dim, + ).to(dtype) + + # For layernorm(s), simply store the layer norm. + if op_name.endswith("layernorm"): + ln_name = "input_layernorm" if op_name.startswith("input") else "post_attention_layernorm" + output_state_dict[layer_name + "." + ln_name + "." 
+ weight_or_bias] = params + + # Split QKV packed weights + elif op_name == "self_attention.query_key_value" and weight_or_bias == "weight": + params_per_tp = params.chunk(dim=0, chunks=megatron_args.tensor_model_parallel_size) + q = torch.empty(0) + k = torch.empty(0) + v = torch.empty(0) + for t in params_per_tp: + qp, kp, vp = t.chunk(3) + q = torch.cat([q, qp]) + k = torch.cat([k, kp]) + v = torch.cat([v, vp]) + output_state_dict[layer_name + ".self_attn.W_pack.weight"] = torch.cat([q, k, v], dim=0).to(dtype).clone().detach().contiguous() + # output_state_dict[layer_name + ".self_attn.q_proj.weight"] = q.to(dtype).clone().detach().contiguous() + # output_state_dict[layer_name + ".self_attn.k_proj.weight"] = k.to(dtype).clone().detach().contiguous() + # output_state_dict[layer_name + ".self_attn.v_proj.weight"] = v.to(dtype).clone().detach().contiguous() + + elif op_name == "mlp.dense_h_to_4h" and weight_or_bias == "weight": + params_per_tp = params.chunk(dim=0, chunks=megatron_args.tensor_model_parallel_size) + gate = torch.empty(0) + up = torch.empty(0) + for t in params_per_tp: + gatep, upp = t.chunk(2) + gate = torch.cat([gate, gatep]) + up = torch.cat([up, upp]) + output_state_dict[layer_name + ".mlp.gate_proj.weight"] = gate.to(dtype).clone().detach().contiguous() + output_state_dict[layer_name + ".mlp.up_proj.weight"] = up.to(dtype).clone().detach().contiguous() + + # Transpose the weights. + elif weight_or_bias == "weight": + out_name = megatron_to_transformers[op_name] + output_state_dict[layer_name + out_name + "weight"] = params + + # Copy the bias. + elif weight_or_bias == "bias": + out_name = megatron_to_transformers[op_name] + output_state_dict[layer_name + out_name + "bias"] = params + + # rotary_base = 10000 + # inv_freq = 1.0 / (rotary_base ** (torch.arange(0, hidden_size_per_head, 2).float() / hidden_size_per_head)) + # output_state_dict[layer_name + '.self_attn.rotary_emb.inv_freq'] = inv_freq + + if config.num_hidden_layers != (layer_idx + 1): + raise ValueError(f"Expected {config.num_hidden_layers} layers but found {layer_idx + 1}") + + # The final layernorm. + print("Converting final layernorm") + params = get_element_from_dict_by_path(tp_state_dicts[0], str(path)) + output_state_dict["model.norm.weight"] = params["final_layernorm.weight"].to(dtype) + + # For LM head, transformers' wants the matrix to weight embeddings. + print("Converting LM head") + output_state_dict["lm_head.weight"] = state_dict['model']['lm_head']['weight'].to(dtype) + + # It should be done! + print("Conversion from Megatron-LM to Transformers is done!") + + # Print the structure of converted state dict. + if args.print_checkpoint_structure: + recursive_print(None, output_state_dict) + + # Save tokenizer based on args + copy_tokenizer(args=args) + + # Store the config to file. + print("Saving config") + config.save_pretrained(args.save_path) + + # Store the state_dict to file. 
+ max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size + shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size) + + # Save the model + for shard_file, shard in shards.items(): + torch.save(shard, os.path.join(args.save_path, shard_file)) + + if index is None: + print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}") + else: + save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + print( + f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + +def convert_checkpoint_from_transformers_to_megatron(args): + """ + Convert a checkpoint from HuggingFace Transformers to Megatron-LM. This allows converted checkpoints with variable + tensor parallelism and pipeline parallelism sizes. It takes as input a checkpoint from HuggingFace Transformers + which can have multiple shards. + + Args: + args (argparse.Namespace): the arguments to the script + + """ + os.makedirs(args.save_path, exist_ok=True) + # Search in directory above this + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + if args.megatron_path is not None: + print("args.megatron_path", args.megatron_path) + sys.path.insert(0, args.megatron_path) + try: + from megatron.tokenizer.tokenizer import _vocab_size_with_padding + from megatron.fs_utils import create_read_file_system + # except ModuleNotFoundError: + except Exception as e: + print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. 
Exiting.") + print(e) + exit(1) + + # load the transformers model state dict and config + sub_dirs = [x for x in os.listdir(args.load_path) if x.startswith("pytorch_model")] + if len(sub_dirs) == 1: + checkpoint_name = "pytorch_model.bin" + state_dict = torch.load(os.path.join(args.load_path, checkpoint_name), map_location="cpu") + else: + num_checkpoints = len(sub_dirs) - 1 + state_dict = merge_transformers_sharded_states(args.load_path, num_checkpoints) + + config = AutoConfig.from_pretrained(args.load_path, trust_remote_code=True) + + # Saving the tracker file + tracker_filepath = os.path.join(args.save_path, "latest_checkpointed_iteration.txt") + with open(tracker_filepath, "w") as f: + f.write("release") + + # create `release` dir in args.load_path + release_dir = os.path.join(args.save_path, "release") + os.makedirs(release_dir, exist_ok=True) + + # megatron args + megatron_args = { + "orig_vocab_size": config.vocab_size, + "max_position_embeddings": config.model_max_length, + "hidden_size": config.hidden_size, + "num_layers": config.num_hidden_layers, + "num_attention_heads": config.num_attention_heads, + "ffn_hidden_size": config.intermediate_size, + "tensor_model_parallel_size": args.target_tensor_model_parallel_size, + "pipeline_model_parallel_size": args.target_pipeline_model_parallel_size, + "data_parallel_size": args.target_data_parallel_size, + "make_vocab_size_divisible_by": args.make_vocab_size_divisible_by, + "rank": 0, + "tokenizer_type": "GPT2BPETokenizer", + } + + margs = types.SimpleNamespace() + for k, v in megatron_args.items(): + setattr(margs, k, v) + + # params dtype + if args.target_params_dtype == "fp16": + dtype = torch.float16 + elif args.target_params_dtype == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + setattr(margs, "params_dtype", dtype) + + # save dummy optim state dict + dummy_optim_state_dict = {} + dummy_optim_state_dict["optimizer"] = { + "step": 0, + "param_groups": [ + { + "lr": 0.0, + "beta1": 0.0, + "beta2": 0.0, + "eps": 0.0, + "weight_decay": 0.0, + "correct_bias": False, + "params": [], + } + ], + } + if args.use_distributed_optimizer: + for i in range(args.target_pipeline_model_parallel_size): + for j in range(args.target_tensor_model_parallel_size): + for k in range(args.target_data_parallel_size): + if args.target_pipeline_model_parallel_size == 1: + checkpoint_dir = f"mp_rank_{j:02d}_{i:03d}" + else: + checkpoint_dir = f"mp_rank_{j:02d}_{i:03d}_{k:03d}" + checkpoint_dir = os.path.join(release_dir, checkpoint_dir) + os.makedirs(checkpoint_dir, exist_ok=True) + torch.save( + dummy_optim_state_dict, + os.path.join(checkpoint_dir, "optim.pt"), + ) + + # Convert. 
+ print("Converting") + output_state_dict = [] + for i in range(args.target_tensor_model_parallel_size): + output_state_dict.append({}) + + # Embedding layer + print("converting embedding layer") + # pos_embedding = state_dict["transformer.wpe.weight"].to(dtype) + word_embedding = state_dict["model.embed_tokens.weight"].to(dtype) + orig_vocab_size = config.vocab_size + padded_vocab_size = _vocab_size_with_padding(orig_vocab_size, margs) + setattr(margs, "padded_vocab_size", padded_vocab_size) + # Cut out extra padding we don't need + if orig_vocab_size > padded_vocab_size: + full_word_embed = word_embedding[0:padded_vocab_size, :] + # Expanding embedding to larger size by replicating final entry + elif orig_vocab_size < padded_vocab_size: + padding_size = padded_vocab_size - orig_vocab_size + full_word_embed = torch.cat((word_embedding, word_embedding[-1].unsqueeze(0).expand(padding_size, -1))) + # Same size! + else: + full_word_embed = word_embedding + + # Split into new tensor model parallel sizes + out_word_embed = torch.chunk(full_word_embed, args.target_tensor_model_parallel_size, dim=0) + for i in range(args.target_tensor_model_parallel_size): + word_emb_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.language_model.embedding.word_embeddings" + ) + word_emb_dict["weight"] = out_word_embed[i] + + # Transformer layers + print("converting transformer layers") + if config.num_hidden_layers % args.target_tensor_model_parallel_size != 0: + raise ValueError( + f"Number of layers ({config.num_hidden_layers}) must be divisible by number of tensor parallelism" + f" ({args.target_tensor_model_parallel_size})" + ) + num_layers = config.num_hidden_layers // args.target_pipeline_model_parallel_size + + layer_re = re.compile(r"model.layers\.(\d+)\.([a-zA-Z0-9_.]+)\.([A-Za-z]+)") + # The number of heads. + heads = config.num_attention_heads + # The hidden_size per head. + hidden_size_per_head = config.hidden_size // config.num_attention_heads + weight_or_bias = "weight" + print("args.target_pipeline_model_parallel_size", args.target_pipeline_model_parallel_size) + for pp_rank in range(args.target_pipeline_model_parallel_size): + layer_offset = pp_rank * num_layers + if pp_rank > 0: + output_state_dict = [] + for i in range(args.target_tensor_model_parallel_size): + output_state_dict.append({}) + print("num_layers", num_layers) + for layer in range(num_layers): + pp_layer_id = layer + layer_offset + layers_to_copy = [ + layer_name + for layer_name in state_dict.keys() + if layer_name.startswith(f"model.layers.{pp_layer_id}.") + ] + qkv_weight_to_combine = {} + mlp_weight_to_combine = {} + for layer_name in layers_to_copy: + m = layer_re.match(layer_name) + # Stop if that's not a layer + if m is None: + break + + # The index of the layer. + _ = int(m.group(1)) + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? 
+ weight_or_bias = m.group(3) + params = state_dict[layer_name].to(dtype) + # handle layernorm + if op_name.endswith("layernorm"): + # out_name = "input_layernorm" if op_name.endswith("1") else "post_attention_layernorm" + out_name = op_name + layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}" + + elif 'self_attn.o_proj' in op_name and weight_or_bias == 'weight': + layer_name = f"layers.{layer}.self_attention.dense.{weight_or_bias}" + + # handle attention K, V, Q weights + elif 'self_attn.W_pack' in op_name and weight_or_bias == 'weight': + params_tmp = params.unflatten(0, (3, config.hidden_size)) + q_weights = params_tmp[0].chunk(args.target_tensor_model_parallel_size, dim=0) + k_weights = params_tmp[1].chunk(args.target_tensor_model_parallel_size, dim=0) + v_weights = params_tmp[2].chunk(args.target_tensor_model_parallel_size, dim=0) + result_weights = [] + for idx in range(len(q_weights)): + partition_weight = torch.cat([q_weights[idx], k_weights[idx], v_weights[idx]]) + print("partition_weight", partition_weight.shape) + result_weights.append(partition_weight) + + params = torch.cat(result_weights) + layer_name = f"layers.{layer}.self_attention.query_key_value.{weight_or_bias}" + # elif op_name.startswith("self_attn") and weight_or_bias == "weight": + # # transformers stores D X (3*D) but Megatron-LM expects (3*D) X D. + # # params = params.transpose(0, 1).contiguous() + # assert (len(qkv_weight_to_combine) != 3) + # if 'q_proj' in op_name: + # qkv_weight_to_combine['q_proj'] = params + # elif 'k_proj' in op_name: + # qkv_weight_to_combine['k_proj'] = params + # elif 'v_proj' in op_name: + # qkv_weight_to_combine['v_proj'] = params + # + # if len(qkv_weight_to_combine) == 3: + # q_weights = qkv_weight_to_combine['q_proj'].chunk(args.target_tensor_model_parallel_size, dim=0) + # k_weights = qkv_weight_to_combine['k_proj'].chunk(args.target_tensor_model_parallel_size, dim=0) + # v_weights = qkv_weight_to_combine['v_proj'].chunk(args.target_tensor_model_parallel_size, dim=0) + # result_weights = [] + # for idx in range(len(q_weights)): + # partition_weight = torch.cat([q_weights[idx], k_weights[idx], v_weights[idx]]) + # result_weights.append(partition_weight) + # + # params = torch.cat(result_weights) + # layer_name = f"layers.{layer}.self_attention.query_key_value.{weight_or_bias}" + # else: + # continue + + elif op_name.startswith("mlp") and weight_or_bias == "weight": + if 'down_proj' in op_name: + layer_name = f"layers.{layer}.mlp.dense_4h_to_h.{weight_or_bias}" + elif 'gate_proj' in op_name: + assert (len(mlp_weight_to_combine) != 2) + mlp_weight_to_combine['gate_proj'] = params + elif 'up_proj' in op_name: + assert (len(mlp_weight_to_combine) != 2) + mlp_weight_to_combine['up_proj'] = params + + if 'down_proj' not in op_name and len(mlp_weight_to_combine) == 2: + gate_weights = mlp_weight_to_combine['gate_proj'].chunk(args.target_tensor_model_parallel_size, + dim=0) + up_weights = mlp_weight_to_combine['up_proj'].chunk(args.target_tensor_model_parallel_size, + dim=0) + result_weights = [] + for idx in range(len(gate_weights)): + partition_weight = torch.cat([gate_weights[idx], up_weights[idx]]) + result_weights.append(partition_weight) + + params = torch.cat(result_weights) + layer_name = f"layers.{layer}.mlp.dense_h_to_4h.{weight_or_bias}" + elif 'down_proj' not in op_name and len(mlp_weight_to_combine) < 2: + continue + + else: + continue + + if op_name + "." 
+ weight_or_bias in tensor_parallel_params: + dim = 1 if op_name in [ + "self_attn.o_proj", "mlp.down_proj"] else 0 + params = torch.chunk( + params, args.target_tensor_model_parallel_size, dim=dim) + + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.language_model.encoder") + params_dict[layer_name] = ( + params[i].clone().detach().contiguous() if ( + op_name + "." + weight_or_bias in tensor_parallel_params) + else params.clone().detach().contiguous() + ) + + if pp_rank == args.target_pipeline_model_parallel_size - 1: + # handle final layernorm + params = state_dict[f"model.norm.weight"].to(dtype) + layer_name = f"final_layernorm.{weight_or_bias}" + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.language_model.encoder") + params_dict[layer_name] = params.clone().detach().contiguous() + + # add the LM head + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.lm_head") + params_dict["weight"] = state_dict['lm_head.weight'].to( + dtype).clone().detach().contiguous() + + # saving the state dict as per the tp_rank and pp_rank + for tp_rank in range(args.target_tensor_model_parallel_size): + output_state_dict[tp_rank]["checkpoint_version"] = 3.0 + output_state_dict[tp_rank]["args"] = margs + checkpoint_dir = ( + f"mp_rank_{tp_rank:02d}" + if args.target_pipeline_model_parallel_size == 1 + else f"mp_rank_{tp_rank:02d}_{pp_rank:03d}" + ) + if args.use_distributed_optimizer: + checkpoint_name = "model_optim_rng.pt" + else: + checkpoint_name = "model_optim_rng.pt" + output_state_dict[tp_rank]["optimizer"] = dummy_optim_state_dict["optimizer"] + checkpoint_dir = os.path.join(release_dir, checkpoint_dir) + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + if args.print_checkpoint_structure: + print( + f"Checkpoint structure of model state dict shard belonging to TP rank {tp_rank} and PP rank" + f" {pp_rank}:" + ) + recursive_print(None, output_state_dict[tp_rank]) + torch.save(output_state_dict[tp_rank], checkpoint_path) + + copy_tokenizer(args=args) + + +def main(): + parser = argparse.ArgumentParser() + parser = add_checkpointing_args(parser) + parser = add_megatron_checkpoint_args(parser) + parser = add_transformers_checkpoint_args(parser) + args = parser.parse_args() + + if args.convert_checkpoint_from_megatron_to_transformers: + convert_checkpoint_from_megatron_to_transformers(args) + else: + convert_checkpoint_from_transformers_to_megatron(args) + + +if __name__ == "__main__": + main() diff --git a/tools/checkpoint_conversion/baichuan_hf_to_megatron.sh b/tools/checkpoint_conversion/baichuan_hf_to_megatron.sh new file mode 100644 index 0000000000..9b021a6fce --- /dev/null +++ b/tools/checkpoint_conversion/baichuan_hf_to_megatron.sh @@ -0,0 +1,10 @@ +python baichuan_checkpoint_conversion.py \ +--load_path "PATH_TO_CHECKPOINT_DOWNLOADED_FROM_HUGGINGFACE" \ +--save_path "PATH_TO_SAVE_CONVERTED_CHECKPOINT" \ +--target_tensor_model_parallel_size 2 \ +--target_pipeline_model_parallel_size 1 \ +--target_data_parallel_size 16 \ +--target_params_dtype "fp16" \ +--make_vocab_size_divisible_by 1 \ +--print-checkpoint-structure \ +--megatron-path "PATH_TO_MEGATRON_SOURCE_CODE" diff --git a/tools/checkpoint_conversion/baichuan_megatron_to_hf.sh 
b/tools/checkpoint_conversion/baichuan_megatron_to_hf.sh
new file mode 100644
index 0000000000..3f47fe7e77
--- /dev/null
+++ b/tools/checkpoint_conversion/baichuan_megatron_to_hf.sh
@@ -0,0 +1,8 @@
+python tools/checkpoint_conversion/baichuan_checkpoint_conversion.py \
+--convert_checkpoint_from_megatron_to_transformers \
+--load_path "PATH_TO_CHECKPOINT_GENERATED_BY_THIS_REPO" \
+--save_path "PATH_TO_SAVE_CONVERTED_CHECKPOINT" \
+--target_params_dtype "fp16" \
+--make_vocab_size_divisible_by 1 \
+--print-checkpoint-structure \
+--megatron-path "PATH_TO_MEGATRON_SOURCE_CODE"
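
A note on the ALiBi helpers introduced in `megatron/model/utils.py` above: `get_slopes` produces the per-head geometric slope sequence from the ALiBi paper, and `_gen_alibi_mask` adds those slopes (scaled by key position) on top of a causal `-inf` mask, which `ParallelTransformer.forward` then passes down as the attention bias. The snippet below is a minimal CPU-only sketch, not part of the patch; it omits the CUDA device placement and tensor-parallel slicing, and the function names are reused here purely for illustration of the shapes and values involved.

```python
import math
import torch

def get_slopes(n):
    """Per-head ALiBi slopes: a geometric sequence starting at 2**(-8/n) for n heads."""
    def power_of_2(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        return [start * start ** i for i in range(n)]
    if math.log2(n).is_integer():
        return power_of_2(n)
    closest = 2 ** math.floor(math.log2(n))
    return power_of_2(closest) + get_slopes(2 * closest)[0::2][: n - closest]

def gen_alibi_mask(num_heads, max_seq_len):
    """Head-specific linear bias over key positions plus a causal -inf upper triangle."""
    slopes = torch.tensor(get_slopes(num_heads))                     # [heads]
    positions = torch.arange(max_seq_len).expand(num_heads, 1, -1)   # [heads, 1, seq]
    alibi = slopes.view(-1, 1, 1) * positions                        # [heads, 1, seq]
    causal = torch.triu(torch.full((max_seq_len, max_seq_len), float("-inf")), 1)
    return causal.unsqueeze(0) + alibi                               # [heads, seq, seq]

if __name__ == "__main__":
    mask = gen_alibi_mask(num_heads=4, max_seq_len=5)
    print(mask.shape)  # torch.Size([4, 5, 5])
    # Row i of head h holds bias slope_h * j for keys j <= i and -inf for future keys j > i.
    print(mask[0])
```

In the patch itself, the transformer rebuilds this mask whenever the sequence length changes, slices out the heads owned by the local tensor-parallel rank, and either adds it to the expanded padding mask for the fused-softmax path or passes it as the `bias` argument of the Triton `flash_attn_func` when `--use-flash-attn` is set.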