forked from FlagOpen/FlagPerf
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add llama2_70B-Megatron pretraining (FlagOpen#389)
* Add llama2_70B-Megatron pretraining * support vendor shell and update readme * add fp32 training performance data and update readme * fix&add * fix&add * fix&add * fix&add * update framework commit version and readme --------- Co-authored-by: shh2000 <[email protected]>
- Loading branch information
Showing
11 changed files
with
523 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
## 模型信息 | ||
- Introduction | ||
|
||
Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Meta's fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Llama2 outperform open-source chat models on most benchmarks meta's researchers tested, and based on their human evaluations for helpfulness and safety, may be a suitable substitute for closedsource models. Meta provide a detailed description of their approach to fine-tuning and safety improvements of Llama 2-Chat in order to enable the community to build on their work and contribute to the responsible development of LLMs. | ||
|
||
- Paper | ||
[LLAMA2](https://arxiv.org/pdf/2307.09288.pdf) | ||
|
||
- 模型代码来源 | ||
|
||
This case includes code from the LLAMA 2 COMMUNITY LICENSE AGREEMENT License open source project at:https://github.com/facebookresearch/llama-recipes/tree/main | ||
|
||
|
||
## 数据准备 | ||
|
||
### 模型配置及tokenizer准备 | ||
|
||
本测试样例为预训练case,需要下载tokenizer,下载链接为 https://github.com/FlagOpen/FlagScale/tree/main/examples/llama2/tokenizer | ||
|
||
在data_dir下创建tokenizer目录,将上述链接中的tokenizer.model文件下载到此目录中 | ||
|
||
|
||
### 数据集准备 | ||
|
||
本测试样例数据使用FlagScale-llama2预处理好的数据集,下载链接为 | ||
|
||
https://model.ks3-cn-beijing.ksyuncs.com/nlpdata/pile_wikipedia_demo.bin | ||
|
||
https://model.ks3-cn-beijing.ksyuncs.com/nlpdata/pile_wikipedia_demo.idx | ||
|
||
将上述两个文件放置于data_dir下。 | ||
|
||
This case includes datasets from the MIT License open source project at https://github.com/EleutherAI/the-pile | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
### 数据集引用 | ||
|
||
``` | ||
@article{pile, | ||
title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling}, | ||
author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor}, | ||
journal={arXiv preprint arXiv:2101.00027}, | ||
year={2020} | ||
} | ||
``` | ||
|
||
### 框架与芯片支持情况 | ||
| | Pytorch | | ||
| ---------- | ------- | | ||
| Nvidia GPU | ✅ | | ||
| 昆仑芯 XPU | N/A | | ||
| 天数智芯 | N/A | |
145 changes: 145 additions & 0 deletions
145
training/benchmarks/llama2_70B/megatron/megatron_main.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
export PYTHONPATH=$PYTHONPATH:/workspace/FlagScale | ||
export CUDA_DEVICE_MAX_CONNECTIONS=1 | ||
|
||
DATA_DIR=$1 | ||
GPUS_PER_NODE=$2 | ||
NNODES=$3 | ||
NODE_RANK=$4 | ||
MASTER_ADDR=$5 | ||
MASTER_PORT=$6 | ||
TRAIN_SAMPLES=$7 | ||
TP=$8 | ||
PP=$9 | ||
M_BATCHSIZE=${10} | ||
G_BATCHSIZE=${11} | ||
SEQLENGTH=${12} | ||
FLASH_ATTN=${13} | ||
RECOMPUTE=${14} | ||
VENDOR_SHELL=${15} | ||
|
||
echo $DATA_DIR | ||
echo $GPUS_PER_NODE | ||
echo $NNODES | ||
echo $NODE_RANK | ||
echo $MASTER_ADDR | ||
echo $MASTER_PORT | ||
echo $TRAIN_SAMPLES | ||
echo $TP | ||
echo $PP | ||
echo $M_BATCHSIZE | ||
echo $G_BATCHSIZE | ||
echo $SEQLENGTH | ||
echo $FLASH_ATTN | ||
echo $RECOMPUTE | ||
echo $VENDOR_SHELL | ||
|
||
DATA_PATH=$DATA_DIR/llama_00_text_document | ||
TOKENIZER_PATH=$DATA_DIR/tokenizer/tokenizer.model | ||
|
||
DISTRIBUTED_ARGS=" | ||
--nproc_per_node $GPUS_PER_NODE \ | ||
--nnodes $NNODES \ | ||
--node_rank $NODE_RANK \ | ||
--master_addr $MASTER_ADDR \ | ||
--master_port $MASTER_PORT | ||
" | ||
|
||
if [ "$FLASH_ATTN" = "True" ]; then | ||
TRAINING_ARGS=" | ||
--train-samples $TRAIN_SAMPLES \ | ||
--eval-iters 0 \ | ||
--tensor-model-parallel-size $TP \ | ||
--pipeline-model-parallel-size $PP \ | ||
--micro-batch-size $M_BATCHSIZE \ | ||
--global-batch-size $G_BATCHSIZE \ | ||
--disable-bias-linear \ | ||
--use-distributed-optimizer \ | ||
--use-flash-attn | ||
" | ||
else | ||
TRAINING_ARGS=" | ||
--train-samples $TRAIN_SAMPLES \ | ||
--eval-iters 0 \ | ||
--tensor-model-parallel-size $TP \ | ||
--pipeline-model-parallel-size $PP \ | ||
--micro-batch-size $M_BATCHSIZE \ | ||
--global-batch-size $G_BATCHSIZE \ | ||
--disable-bias-linear \ | ||
--use-distributed-optimizer | ||
" | ||
fi | ||
|
||
MIXED_PRECISION_ARGS=" | ||
--bf16 | ||
" | ||
|
||
DATA_ARGS=" | ||
--data-path $DATA_PATH \ | ||
--tokenizer-type Llama2Tokenizer \ | ||
--tokenizer-model $TOKENIZER_PATH \ | ||
--split 1 | ||
" | ||
|
||
NETWORK_ARGS=" | ||
--num-layers 80 \ | ||
--hidden-size 8192 \ | ||
--num-attention-heads 64 \ | ||
--ffn-hidden-size 28672 \ | ||
--seq-length $SEQLENGTH \ | ||
--max-position-embeddings $SEQLENGTH \ | ||
--normalization RMSNorm \ | ||
--group-query-attention \ | ||
--num-query-groups 8 \ | ||
--use-rotary-position-embeddings \ | ||
--no-position-embedding \ | ||
--swiglu \ | ||
--multiple-of 4096 \ | ||
--sequence-parallel \ | ||
--untie-embeddings-and-output-weights | ||
" | ||
|
||
if [ "$RECOMPUTE" = "True" ]; then | ||
RECOMPUTE_ARGS=" | ||
--recompute-activations | ||
" | ||
|
||
INITIALIZATION_ARGS=" | ||
--init-method-std 0.02 \ | ||
--seed 1234 | ||
" | ||
|
||
REGULARIZATION_ARGS=" | ||
--attention-dropout 0.0 \ | ||
--hidden-dropout 0.0 \ | ||
--weight-decay 1e-2 \ | ||
--adam-beta1 0.9 \ | ||
--adam-beta2 0.95 \ | ||
--clip-grad 1.0 | ||
" | ||
|
||
LEARNING_RATE_ARGS=" | ||
--lr 0.00015 \ | ||
--min-lr 1.0e-5 \ | ||
--lr-decay-style cosine \ | ||
--lr-warmup-fraction .01 | ||
" | ||
|
||
LOGGING_ARGS=" | ||
--log-interval 1 | ||
" | ||
|
||
source $VENDOR_SHELL | ||
cmd="torchrun $DISTRIBUTED_ARGS /workspace/FlagScale/pretrain_llama.py \ | ||
$TRAINING_ARGS \ | ||
$MIXED_PRECISION_ARGS \ | ||
$DATA_ARGS \ | ||
$NETWORK_ARGS \ | ||
$INITIALIZATION_ARGS \ | ||
$REGULARIZATION_ARGS \ | ||
$LEARNING_RATE_ARGS \ | ||
$CHECKPOINTING_ARGS \ | ||
$RECOMPUTE_ARGS \ | ||
$LOGGING_ARGS | ||
" | ||
echo $cmd | ||
eval $cmd |
92 changes: 92 additions & 0 deletions
92
training/benchmarks/llama2_70B/megatron/run_pretraining.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import subprocess | ||
from argparse import ArgumentParser | ||
import os | ||
import sys | ||
from importlib import import_module | ||
|
||
|
||
def parse_args(): | ||
'''we parse ddp related args, check system config args, and running env | ||
args such as --data_dir_xxx. Then pass all useful args to the real | ||
training script. | ||
''' | ||
parser = ArgumentParser(description="flagscale main python") | ||
parser.add_argument("--nproc_per_node", type=int, required=True) | ||
parser.add_argument("--nnodes", type=int, required=True) | ||
parser.add_argument("--node_rank", type=int, required=True) | ||
parser.add_argument("--master_addr", type=str, required=True) | ||
parser.add_argument("--master_port", type=int, required=True) | ||
parser.add_argument("--vendor", type=str, required=True) | ||
parser.add_argument("--data_dir", type=str, required=True) | ||
parser.add_argument("--log_dir", type=str, required=True) | ||
parser.add_argument("--flagperf_config_file", type=str, required=True) | ||
args, unknown_args = parser.parse_known_args() | ||
args.unknown_args = unknown_args | ||
return args | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
print(args) | ||
|
||
sys.path.append(os.path.dirname(args.flagperf_config_file)) | ||
config_file = os.path.basename(args.flagperf_config_file).split('.')[0] | ||
config_dir_path = os.path.dirname(args.flagperf_config_file) | ||
|
||
module = import_module(config_file) | ||
|
||
seqlength = getattr(module, 'seqlength') | ||
batchsize = getattr(module, 'batchsize') | ||
accumulate_steps = getattr(module, 'accumulate_steps') | ||
train_tokens = getattr(module, 'train_tokens') | ||
theoryflops = getattr(module, 'theoryflops') | ||
epochs = getattr(module, 'epochs') | ||
flashattn = getattr(module, 'flashattn') | ||
recompute = getattr(module, 'recompute') | ||
tensor_parallel = getattr(module, 'tensor_parallel') | ||
pipeline_parallel = getattr(module, 'pipeline_parallel') | ||
|
||
train_samples = int((train_tokens * epochs) // seqlength) | ||
mbs = batchsize | ||
gbs = batchsize * args.nproc_per_node * args.nnodes * accumulate_steps // (tensor_parallel * | ||
pipeline_parallel) | ||
|
||
task_log_file = os.path.join(args.log_dir, "megatron.log.txt") | ||
|
||
exec_cmd = "bash megatron_main.sh" | ||
exec_cmd = exec_cmd + " " + args.data_dir | ||
exec_cmd = exec_cmd + " " + str(args.nproc_per_node) | ||
exec_cmd = exec_cmd + " " + str(args.nnodes) | ||
exec_cmd = exec_cmd + " " + str(args.node_rank) | ||
exec_cmd = exec_cmd + " " + args.master_addr | ||
exec_cmd = exec_cmd + " " + str(args.master_port) | ||
exec_cmd = exec_cmd + " " + str(train_samples) | ||
exec_cmd = exec_cmd + " " + str(tensor_parallel) | ||
exec_cmd = exec_cmd + " " + str(pipeline_parallel) | ||
exec_cmd = exec_cmd + " " + str(mbs) | ||
exec_cmd = exec_cmd + " " + str(gbs) | ||
exec_cmd = exec_cmd + " " + str(seqlength) | ||
exec_cmd = exec_cmd + " " + str(flashattn) | ||
exec_cmd = exec_cmd + " " + str(recompute) | ||
exec_cmd = exec_cmd + " " + os.path.join(config_dir_path, "training_adapter.sh") | ||
|
||
with open(task_log_file, "w") as f: | ||
p = subprocess.Popen(exec_cmd, | ||
shell=True, | ||
stdout=f, | ||
stderr=subprocess.STDOUT) | ||
p.wait() | ||
|
||
time_per_step = -1.0 | ||
with open(task_log_file) as f: | ||
for line in f.readlines(): | ||
if "elapsed time per iteration (ms): " in line: | ||
info = line.split("|")[2] | ||
steptime = info.split(":")[1] | ||
time_per_step = float(steptime) / 1000 | ||
|
||
whole_tps = gbs * seqlength / time_per_step | ||
chip_tps = whole_tps / (args.nproc_per_node * args.nnodes) | ||
print("System tokens per second: ", whole_tps) | ||
print("Tokens/p/s: ", chip_tps) | ||
print("MFU: ", chip_tps * 70000000000.0 * 6 / theoryflops) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
FROM nvcr.io/nvidia/pytorch:23.09-py3 | ||
RUN /bin/bash -c "pip config set global.index-url https://mirror.baidu.com/pypi/simple" | ||
RUN /bin/bash -c "uname -a" | ||
RUN /bin/bash -c alias python3=python |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/bin/bash | ||
# using github mirrors to avoid github TTL | ||
git clone https://githubfast.com/FlagOpen/FlagScale | ||
git checkout 26cd6643c472f853e077779abaa51bb6a1c140bf | ||
echo 'export PYTHONPATH=$PYTHONPATH:/workspace/FlagScale' >> /root/.bashrc | ||
source /root/.bashrc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
|
||
### Nvidia GPU配置与运行信息参考 | ||
#### 环境配置 | ||
- ##### 硬件环境 | ||
- 机器型号: NVIDIA H800(80G) | ||
- 加速卡型号: NVIDIA_H800-80GB | ||
- CPU型号: Intel(R) Xeon(R) Platinum 8462Y+ | ||
- 多机网络类型、带宽: InfiniBand, 200Gb/s | ||
|
||
- ##### 软件环境 | ||
- OS版本:Ubuntu 22.04 LTS | ||
- OS kernel版本: 5.15.0-25-generic | ||
- 加速卡驱动版本:535.129.03 | ||
- Docker 版本:24.0.7 | ||
- 训练框架版本:FlagScale.git@26cd664 | ||
- 依赖软件版本:sentencepiece | ||
|
||
- ##### 并行策略 | ||
|
||
- 并行技术:张量、流水、数据混合并行,具体并行方案见“运行情况”章节 | ||
- 实施者:FlagScale | ||
- 实施细节:/ | ||
|
||
- ##### 优化策略 | ||
|
||
- flash attention 2 | ||
|
||
### 运行情况 | ||
|
||
* 输入批尺寸 | ||
1. local_batchsize(micro_batchsize),简写为LBS,即实际进入模型的张量批尺寸,为config_H100x4x8.py中所写,在本case中默认为1 | ||
2. seqlength(max_position_embedding),简写为MPE,即实际进入模型的序列长度,为config_H100x4x8.py中所写,在本case中默认为4096 | ||
3. gradient_accumulate_steps,简写为GAS,即梯度累加步数,为ds_config.json中所写,在本case中默认为44 | ||
4. global_batchsize恒等于local_batchsize\*gradient_accumulate_steps\*data_parallel_size。在本case中,data_parallel_size=world_size/TPsize/PPsize。 | ||
|
||
* 通用指标 | ||
|
||
| 指标名称 | 指标值 | 特殊说明 | | ||
| ------------ | -------------------------- | ---------------------------------- | | ||
| 任务类别 | 自然语言理解 | | | ||
| 模型 | llama2_70b | | | ||
| 数据集 | pile wikipedia | | | ||
| 数据精度 | precision,见“性能指标” | 可选fp32/amp/fp16/bf16 | | ||
| 超参修改 | parallel,见“性能指标” | 格式为TPxPPyDPz,例如TP2PP1DP4 | | ||
| 超参修改 | fix_hp,见“性能指标” | 跑满硬件设备评测吞吐量所需特殊超参 | | ||
| 硬件设备简称 | nvidia H800 | | | ||
| 硬件存储使用 | mem,见“性能指标” | 通常称为“显存”,单位为GiB | | ||
| 计算使用率 | MFU,见“性能指标” | 参见PaLM论文定义 | | ||
| **吞吐量** | **token/p/s,见“性能指标”** | 平均单卡每秒处理的token数 | | ||
|
||
* 性能指标 | ||
|
||
值得注意的是,下列第4组实验的global_batchsize与llama2原始论文相同, 训练100 step,此项实验也将作为精度对齐所用实验。 | ||
|
||
| 配置 | precision | parallel | fix_hp | token/p/s | loss | mem | MFU | | ||
| ------------------ | -------- | --------- | ---------------- | ------ | ------- | --------- | --------- | | ||
| H800四机32卡(4x8) | fp32 | TP8PP4DP1 | recompute=True,(theoryflops=495T) | 253.61 | 0.94 | 77/80 | 21.5% | | ||
| H800四机32卡(4x8) | amp | TP8PP4DP1 | / | 641.93 | 5.7 | 62/80 | 27.2% | | ||
| H800四机32卡(4x8) | amp | TP4PP8DP1 | / | 791.37 | 5.6 | 74/80 | 33.6% | | ||
| H800四机32卡(4x8) | amp | TP4PP8DP1 | GAS=1024(GBS=1024=4M tokens) | 908.29 | 7.1 | 74/80 | 38.6% | |
10 changes: 10 additions & 0 deletions
10
training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
seqlength = 4096 | ||
batchsize = 1 | ||
accumulate_steps = 44 | ||
train_tokens = 100000000 | ||
theoryflops = 989000000000000.0 | ||
epochs = 1 | ||
flashattn = True | ||
recompute = False | ||
tensor_parallel = 8 | ||
pipeline_parallel = 4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
sentencepiece |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
echo "[Prompt] nvidia adaption is NULL, for other Vendors" |
Oops, something went wrong.