add alibi position embedding and support baichuan #54

Open

qyccc wants to merge 48 commits into main.

Changes from 23 of 48 commits:
a3e55c1
Update transformer.py
qyccc Dec 1, 2023
cf7c168
Create pretrain_baichuan.py
qyccc Dec 1, 2023
9ffe45c
Create baichuan_model.py
qyccc Dec 1, 2023
852f44b
Create baichuan_checkpoint_conversion.py
qyccc Dec 1, 2023
e08a890
Create baichuan_hf_to_megatron.sh
qyccc Dec 1, 2023
19460d8
Update baichuan_hf_to_megatron.sh
qyccc Dec 1, 2023
642599f
Create baichuan_megatron_to_hf.sh
qyccc Dec 1, 2023
af89fda
Update README.md
qyccc Dec 1, 2023
46ae7dd
Update README.md
qyccc Dec 1, 2023
eb1a258
Update README_zh.md
qyccc Dec 1, 2023
8330e86
Update README_zh.md
qyccc Dec 1, 2023
86d996b
Update README.md
qyccc Dec 1, 2023
6ccfb61
Update arguments.py
qyccc Dec 1, 2023
b0ac4d3
add collate_fn
qyccc Dec 11, 2023
09d01a8
Update training.py
qyccc Dec 11, 2023
68c77a5
add collate_fn
qyccc Dec 11, 2023
4c582c0
add alibi
qyccc Dec 11, 2023
3d6b633
Update tokenizer.py
qyccc Dec 11, 2023
9b709dc
Update transformer.py
qyccc Dec 11, 2023
a167ea2
Create Baichuan_13_standalone.sh
qyccc Dec 14, 2023
d5fbd89
Update README.md
qyccc Dec 14, 2023
448de2b
Update __init__.py
qyccc Dec 14, 2023
e2af6ca
solve Unsupported gpu architecture 'compute_90'
qyccc Dec 14, 2023
aad5993
Update Baichuan_13_standalone.sh
qyccc Dec 19, 2023
d9140f4
Update arguments.py
qyccc Dec 19, 2023
48538ce
Update transformer.py
qyccc Dec 19, 2023
531c2e4
Update utils.py
qyccc Dec 19, 2023
ca53b5e
Update tokenizer.py
qyccc Dec 19, 2023
b4443d8
Update tokenizer.py
qyccc Dec 19, 2023
ae70b9b
Update arguments.py
qyccc Dec 19, 2023
c2e1867
Update Baichuan_13_standalone.sh
qyccc Dec 19, 2023
60f218f
Update training.py
qyccc Dec 19, 2023
917992c
support flash attention v1
qyccc Dec 20, 2023
3b36b50
Update Baichuan_13_standalone.sh
qyccc Dec 20, 2023
64255e3
Update arguments.py
qyccc Dec 20, 2023
c6607c2
initialize certain Tensors on GPU
qyccc Dec 20, 2023
39ecc65
Update transformer.py
qyccc Dec 21, 2023
975bac6
Update __init__.py
qyccc Dec 26, 2023
5b8fca1
Update Baichuan_13_standalone.sh
qyccc Dec 26, 2023
d636039
remove collate_fn
qyccc Jan 2, 2024
0dd0c21
Update indentent in transformer.py
qyccc Jan 2, 2024
56684aa
Update data_samplers.py
qyccc Jan 2, 2024
451ffab
Update training.py
qyccc Jan 2, 2024
b4a2133
Update transformer.py
qyccc Jan 2, 2024
7cf20bd
Update transformer.py
qyccc Jan 2, 2024
5af1af2
Update utils.py
qyccc Jan 2, 2024
7de3e0a
Update transformer.py
qyccc Jan 2, 2024
251d6c5
Update transformer.py
qyccc Jan 2, 2024
17 changes: 16 additions & 1 deletion README.md
@@ -89,23 +89,38 @@ This tool helps convert the format of parameters between Megatron-LLaMA/Megatron-

**HuggingFace to Megatron-LLaMA**

For LLaMA:
```
sh tools/checkpoint_conversion/hf_to_megatron.sh
```
For Baichuan:
```
sh tools/checkpoint_conversion/baichuan_hf_to_megatron.sh
```

**Megatron-LLaMA to HuggingFace**

For LLaMA:
```
sh tools/checkpoint_conversion/megatron_to_hf.sh
```
For Baichuan:
```
sh tools/checkpoint_conversion/baichuan_megatron_to_hf.sh
```

#### B. Launching scripts

**Single-node launching**

For LLaMA:
```
sh examples/LLaMA/LLaMA_13_standalone.sh
```
For Baichuan:
```
sh examples/Baichuan_13_standalone.sh
```

**Distributed launching**

10 changes: 10 additions & 0 deletions README_zh.md
@@ -85,17 +85,27 @@ Megatron-LLaMA is used in basically the same way as Megatron-LM; for details, see [Meg

**HuggingFace to Megatron-LLaMA**

LLaMA:
```
sh tools/checkpoint_conversion/hf_to_megatron.sh
```
Baichuan:
```
sh tools/checkpoint_conversion/baichuan_hf_to_megatron.sh
```

After training, convert the resulting weights into a format supported by HuggingFace for later use:

**Megatron-LLaMA to HuggingFace**

LLaMA:
```
sh tools/checkpoint_conversion/megatron_to_hf.sh
```
Baichuan:
```
sh tools/checkpoint_conversion/baichuan_megatron_to_hf.sh
```

### B. LLaMA training scripts

90 changes: 90 additions & 0 deletions examples/Baichuan_13_standalone.sh
@@ -0,0 +1,90 @@
DATASET_1="<PATH TO THE FIRST DATASET>"
DATASET_2="<PATH TO THE SECOND DATASET>"
DATASET_3="<PATH TO THE THIRD DATASET>"
DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}"

TP_SIZE=2
PP_SIZE=1
WORLD_SIZE=8
MICRO_BATCH_SIZE=2
# The trailing factor (8) is the number of gradient-accumulation micro-steps
GLOBAL_BATCH_SIZE=$((($WORLD_SIZE * $MICRO_BATCH_SIZE) / ($TP_SIZE * $PP_SIZE) * 8))
# GLOBAL_BATCH_SIZE=128
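# With the values above: (8 * 2) / (2 * 1) * 8 = 8 * 8 = 64 samples per global batch.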

JOB_NAME="LLaMA_tp${TP_SIZE}_pp${PP_SIZE}_mbs${MICRO_BATCH_SIZE}_gpus${WORLD_SIZE}"
qyccc marked this conversation as resolved.
Show resolved Hide resolved

LOAD_CHECKPOINT_PATH="PATH TO THE MODEL CHECKPOINT"
SAVE_CHECKPOINT_PATH="PATH TO SAVE MODEL CHECKPOINT"
TOKENIZER_PATH="PATH OR NAME FOR PRETRAINED TOKENIZER"
TENSORBOARD_DIR="TENSORBOARD DIRECTORY"

TRAIN_ITERS=1000
EVAL_ITERS=10
EVAL_INTERVAL=1000
SAVE_INTERVAL=100
LOG_INTERVAL=1

# Setting --tensorboard-queue-size to 1 significantly slows down the training
options=" \
--finetune \
--sequence-parallel \
--tensor-model-parallel-size ${TP_SIZE} \
--pipeline-model-parallel-size ${PP_SIZE} \
--num-layers 40 \
--hidden-size 5120 \
--num-attention-heads 40 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--no-position-embedding \
--position-embedding-type alibi \
--swiglu \
--ffn-hidden-size 13696 \
--disable-bias-linear \
--RMSNorm \
--attention-dropout 0 \
--hidden-dropout 0 \
--layernorm-epsilon 1e-6 \
--causal-lm \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path $TOKENIZER_PATH \
--make-vocab-size-divisible-by 1 \
--init-method-std 0.01 \
--micro-batch-size ${MICRO_BATCH_SIZE} \
--global-batch-size ${GLOBAL_BATCH_SIZE} \
--train-iters ${TRAIN_ITERS} \
--lr 6.0e-5 \
--lr-decay-iters 10 \
--lr-warmup-iters 5 \
--min-lr 6.0e-6 \
--override-opt_param-scheduler \
--lr-decay-style cosine \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--clip-grad 1.0 \
--weight-decay 0.1 \
--overlapped-distributed-optimizer \
--reduce-bucket-size=2e8 \
--no-gradient-accumulation-fusion \
--dataloader-type cyclic \
--data-impl mmap \
--data-path ${DATASET} \
--split 98,2,0 \
--eval-interval ${EVAL_INTERVAL} \
--eval-iters ${EVAL_ITERS} \
--save-interval ${SAVE_INTERVAL} \
--save ${SAVE_CHECKPOINT_PATH} \
--load ${LOAD_CHECKPOINT_PATH} \
--no-load-optim \
--log-interval ${LOG_INTERVAL} \
--tensorboard-dir ${TENSORBOARD_DIR} \
--tensorboard-queue-size 1000 \
--log-timers-to-tensorboard \
--log-batch-size-to-tensorboard \
--log-validation-ppl-to-tensorboard \
--job-name ${JOB_NAME} \
--bf16 \
--recompute-activations \
--recompute-granularity selective \
--use-flash-attn"

torchrun --nproc_per_node=8 --master_port=29500 pretrain_baichuan.py ${options}
3 changes: 3 additions & 0 deletions megatron/arguments.py
@@ -526,6 +526,9 @@ def _add_network_size_args(parser):
'attention. This is set to '
' args.hidden_size // args.num_attention_heads '
'if not provided.')
group.add_argument('--position-embedding-type', type=str, default='learned_absolute',
choices=['learned_absolute', 'rope', 'alibi'],
help='Position embedding type.')
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
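For context on the new `alibi` choice above: ALiBi drops learned/rotary position embeddings and instead adds a fixed, per-head linear bias to the attention scores. The following is a minimal, illustrative sketch of how the head slopes and the additive bias are commonly computed (following the ALiBi paper); it is not the tensor-parallel implementation this PR adds to transformer.py, and `build_alibi_bias` with its shapes is an assumption for illustration only.

```python
import math

import torch


def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Per-head slopes from the ALiBi paper: for n heads with n a power of two,
    # slope_i = 2 ** (-8 * i / n) for i = 1..n.
    def power_of_2_slopes(n):
        start = 2 ** (-8.0 / n)
        return [start ** (i + 1) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return torch.tensor(power_of_2_slopes(num_heads))
    # Non-power-of-two head counts (e.g. Baichuan-13B's 40 heads) interleave
    # slopes from the next power-of-two sequence, as in the reference code.
    closest = 2 ** math.floor(math.log2(num_heads))
    extra = power_of_2_slopes(2 * closest)[0::2][: num_heads - closest]
    return torch.tensor(power_of_2_slopes(closest) + extra)


def build_alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
    # bias[h, i, j] = slope[h] * (j - i): zero on the diagonal and increasingly
    # negative for more distant past positions; it is added to the attention
    # scores before softmax, while the causal mask still hides j > i.
    slopes = get_alibi_slopes(num_heads)              # [num_heads]
    pos = torch.arange(seq_len)
    rel = pos[None, :] - pos[:, None]                 # [seq, seq], value j - i
    return slopes[:, None, None] * rel[None, :, :]    # [num_heads, seq, seq]
```

With `--num-attention-heads 40` and `--seq-length 2048` as in the example script, this sketch would produce a `[40, 2048, 2048]` bias broadcast across the batch.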
11 changes: 7 additions & 4 deletions megatron/data/data_samplers.py
@@ -11,7 +11,7 @@
from megatron.core import mpu


def build_pretraining_data_loader(dataset, consumed_samples):
def build_pretraining_data_loader(dataset, consumed_samples, data_collator=None):
"""Buld dataloader given an input dataset."""

if dataset is None:
@@ -40,10 +40,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
args.dataloader_type))

# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
dataloader = torch.utils.data.DataLoader(dataset,
collate_fn=data_collator,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True)
return dataloader


class MegatronPretrainingSampler:

@@ -136,8 +139,8 @@ def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.micro_batch_size > 0, 'self.micro_batch_size > 0'
assert data_parallel_size > 0, 'data_parallel_size > 0'
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
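The new `data_collator` argument is forwarded straight to `torch.utils.data.DataLoader` above. Purely as an illustration (the sample schema, field name, and helper below are assumptions, not the collator this PR wires up in training.py), a collate function that pads variable-length token lists to the longest sample in the batch could look like:

```python
import torch


def pad_collate_fn(batch, pad_token_id=0):
    # batch is a list of dicts, each with a 'text' field of token ids
    # (hypothetical sample schema).
    max_len = max(len(sample['text']) for sample in batch)
    padded = [
        list(sample['text']) + [pad_token_id] * (max_len - len(sample['text']))
        for sample in batch
    ]
    return {'text': torch.tensor(padded, dtype=torch.long)}


# Hypothetical usage:
# loader = build_pretraining_data_loader(dataset, consumed_samples,
#                                        data_collator=pad_collate_fn)
```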
6 changes: 3 additions & 3 deletions megatron/fused_kernels/__init__.py
@@ -23,9 +23,9 @@ def load(args):
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
if int(bare_metal_minor) >= 7:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_90,code=sm_90')
# if int(bare_metal_minor) >= 7:
# cc_flag.append('-gencode')
# cc_flag.append('arch=compute_90,code=sm_90')
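            # (sm_90 / Hopper targets are only accepted by nvcc from CUDA 11.8 onward,
            # so gating on the 11.x minor version alone can fail with
            # "Unsupported gpu architecture 'compute_90'"; the flags are disabled instead.)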

# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
Expand Down
1 change: 1 addition & 0 deletions megatron/model/__init__.py
@@ -6,6 +6,7 @@
from .bert_model import BertModel
from .gpt_model import GPTModel
from .llama_model import LLaMAModel
from .baichuan_model import BaichuanModel
from .t5_model import T5Model
from .language_model import get_language_model
from .module import Float16Module
169 changes: 169 additions & 0 deletions megatron/model/baichuan_model.py
@@ -0,0 +1,169 @@
# Copyright (c) 2023, ALIBABA CORPORATION. All rights reserved.


"""Baichuan model."""

import torch

from megatron import get_args
from megatron.core import tensor_parallel
from .module import MegatronModule

from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal


def post_language_model_processing(lm_output, labels, logit_weights,
parallel_output,
fp16_lm_cross_entropy):

# Output. Format [s b h]
output = parallel_lm_logits(
lm_output,
logit_weights,
parallel_output)

if labels is None:
# [s b h] => [b s h]
return output.transpose(0,1).contiguous()
else:
# [b s] => [s b]
labels = labels.transpose(0,1).contiguous()
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)

# [s b] => [b, s]
loss = loss.transpose(0,1).contiguous()
return loss


class BaichuanModel(MegatronModule):
"""Baichuan Language model."""

def __init__(self,
num_tokentypes=0,
parallel_output=True,
pre_process=True,
post_process=True):
args = get_args()
super(BaichuanModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)

self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
self.sequence_parallel = args.sequence_parallel
self.padded_vocab_size = args.padded_vocab_size

self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)

self.causal_lm = args.causal_lm

if not args.untie_embeddings_and_output_weights and not self.causal_lm:
self.initialize_word_embeddings(init_method_normal)

if self.causal_lm and self.post_process:
self.lm_head = torch.nn.Linear(args.hidden_size, args.padded_vocab_size, bias=False)

def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)

def _causal_lm_process(self, lm_output, labels):
if self.sequence_parallel:
lm_output = tensor_parallel.gather_from_sequence_parallel_region(lm_output, False)
lm_output = lm_output.transpose(0, 1)
logits = self.lm_head(lm_output)

if labels is None:
return logits
else:
loss = None
# Unlike the HuggingFace reference implementation, labels arrive pre-shifted from
# the Megatron data pipeline, so no "shift so that tokens < n predict n" is needed;
# only the final position is dropped below.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., :-1].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0)
shift_logits = shift_logits.view(-1, self.padded_vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

return loss

def forward(self, input_ids, position_ids, attention_mask,
ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
labels=None, tokentype_ids=None, inference_params=None):
lm_output = self.language_model(
input_ids,
position_ids,
attention_mask,
ret_input_ids=ret_input_ids,
ret_position_ids=ret_position_ids,
ret_attn_mask=ret_attn_mask,
inference_params=inference_params)

if self.post_process:
if self.causal_lm:
return self._causal_lm_process(lm_output=lm_output, labels=labels)
else:
return post_language_model_processing(
lm_output, labels,
self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(),
self.parallel_output,
self.fp16_lm_cross_entropy)
else:
return lm_output

def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):

state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
prefix=prefix, keep_vars=keep_vars)
# Save word_embeddings.
if (self.post_process
and not self.pre_process
and not self.untie_embeddings_and_output_weights
and not self.causal_lm):
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(prefix=prefix,
keep_vars=keep_vars)
if self.post_process and self.causal_lm:
state_dict_['lm_head'] = self.lm_head.state_dict()

return state_dict_

def load_state_dict(self, state_dict, strict=True):
"""Customized load."""

if self.causal_lm and self.post_process:
self.lm_head.load_state_dict(state_dict['lm_head'], strict=strict)

# Load word_embeddings.
if self.post_process and \
not self.pre_process \
and not self.untie_embeddings_and_output_weights \
and not self.causal_lm:
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
if self._language_model_key in state_dict:
state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict)
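For reference, the pretrain entry point added by this PR (pretrain_baichuan.py, not shown in this hunk) presumably constructs the model through Megatron's usual `model_provider` hook. A minimal sketch assuming the conventional Megatron-LM pretraining pattern, not code taken from that file:

```python
from megatron.model import BaichuanModel


def model_provider(pre_process=True, post_process=True):
    """Build the Baichuan model for Megatron's pretrain() loop (illustrative sketch)."""
    return BaichuanModel(
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process,
    )
```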