From 0171acaa18136b79b43dedf479f6287552377c9a Mon Sep 17 00:00:00 2001
From: jamesruio <1428173426@qq.com>
Date: Thu, 11 Jan 2024 22:00:57 +0800
Subject: [PATCH] add fp32 training performance data and update readme

---
 .../llama2_70B/megatron/megatron_main.sh      |  16 +-
 .../llama2_70B/megatron/pretrain_llama.py     | 189 ------------------
 .../llama2_70B/megatron/run_pretraining.py    |   4 +-
 .../docker_image/megatron/megatron_install.sh |   2 +-
 training/nvidia/llama2_70B-megatron/README.md |  11 +-
 .../config/config_H800x4x8.py                 |   1 +
 .../run_benchmarks/config/cluster_conf.py     |   2 +-
 7 files changed, 26 insertions(+), 199 deletions(-)
 delete mode 100644 training/benchmarks/llama2_70B/megatron/pretrain_llama.py

diff --git a/training/benchmarks/llama2_70B/megatron/megatron_main.sh b/training/benchmarks/llama2_70B/megatron/megatron_main.sh
index 99ceb3591..a5f01b8ee 100644
--- a/training/benchmarks/llama2_70B/megatron/megatron_main.sh
+++ b/training/benchmarks/llama2_70B/megatron/megatron_main.sh
@@ -14,7 +14,8 @@ M_BATCHSIZE=${10}
 G_BATCHSIZE=${11}
 SEQLENGTH=${12}
 FLASH_ATTN=${13}
-VENDOR_SHELL=${14}
+RECOMPUTE=${14}
+VENDOR_SHELL=${15}
 
 echo $DATA_DIR
 echo $GPUS_PER_NODE
@@ -29,6 +30,7 @@ echo $M_BATCHSIZE
 echo $G_BATCHSIZE
 echo $SEQLENGTH
 echo $FLASH_ATTN
+echo $RECOMPUTE
 echo $VENDOR_SHELL
 
 DATA_PATH=$DATA_DIR/llama_00_text_document
@@ -96,6 +98,11 @@ NETWORK_ARGS="
     --untie-embeddings-and-output-weights
 "
+if [ "$RECOMPUTE" = "True" ]; then
+    RECOMPUTE_ARGS="
+        --recompute-activations
+    "
+fi
 INITIALIZATION_ARGS="
     --init-method-std 0.02 \
     --seed 1234
 "
@@ -117,8 +124,12 @@ LEARNING_RATE_ARGS="
     --lr-warmup-fraction .01
 "
 
+LOGGING_ARGS="
+    --log-interval 1
+"
+
 source $VENDOR_SHELL
-cmd="torchrun $DISTRIBUTED_ARGS pretrain_llama.py \
+cmd="torchrun $DISTRIBUTED_ARGS /workspace/FlagScale/pretrain_llama.py \
     $TRAINING_ARGS \
     $MIXED_PRECISION_ARGS \
     $DATA_ARGS \
@@ -127,6 +138,7 @@ cmd="torchrun $DISTRIBUTED_ARGS pretrain_llama.py \
     $REGULARIZATION_ARGS \
     $LEARNING_RATE_ARGS \
     $CHECKPOINTING_ARGS \
+    $RECOMPUTE_ARGS \
     $LOGGING_ARGS
     "
 echo $cmd
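The RECOMPUTE conditional added above only ever injects --recompute-activations (activation recomputation, trading extra compute for lower activation memory); for any value of RECOMPUTE other than the literal string "True", RECOMPUTE_ARGS stays empty and contributes nothing to the launch command. A minimal Python sketch of that contract (illustrative only, not part of the patch; the helper name is made up):

    def recompute_args(recompute) -> str:
        """Mirror of the shell conditional: emit the flag only for the string 'True'."""
        return "--recompute-activations" if str(recompute) == "True" else ""

    assert recompute_args(True) == "--recompute-activations"
    assert recompute_args(False) == ""  # default: the flag is not added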
diff --git a/training/benchmarks/llama2_70B/megatron/pretrain_llama.py b/training/benchmarks/llama2_70B/megatron/pretrain_llama.py
deleted file mode 100644
index b4e5142af..000000000
--- a/training/benchmarks/llama2_70B/megatron/pretrain_llama.py
+++ /dev/null
@@ -1,189 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-
-"""Pretrain LLaMA."""
-
-import os
-import torch
-from torch import Tensor
-from functools import partial
-from megatron import get_args
-from megatron import print_rank_0
-from megatron import get_timers
-from megatron import get_tokenizer
-from megatron.core import mpu, tensor_parallel
-from megatron.core.enums import ModelType
-from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
-from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
-from megatron.core.datasets.gpt_dataset import GPTDataset
-from megatron.model import LLaMAModel
-from megatron.training import pretrain
-from megatron.utils import (
-    get_ltor_masks_and_position_ids,
-    get_batch_on_this_cp_rank,
-    average_losses_across_data_parallel_group
-)
-from megatron.arguments import core_transformer_config_from_args
-
-
-def model_provider(pre_process=True, post_process=True):
-    """Build the model."""
-    args = get_args()
-    config = core_transformer_config_from_args(args)
-    print_rank_0('building LLaMA model ...')
-    model = LLaMAModel(
-        config=config,
-        num_tokentypes=0,
-        parallel_output=True,
-        pre_process=pre_process,
-        post_process=post_process
-    )
-    return model
-
-
-def get_batch(data_iterator):
-    """Generate a batch."""
-
-    # TODO: this is pretty hacky, find a better way
-    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
-        return None, None, None, None, None
-
-    args = get_args()
-    tokenizer = get_tokenizer()
-
-    # Items and their type.
-    keys = ['text']
-    datatype = torch.int64
-
-    # Broadcast data.
-    if data_iterator is not None:
-        data = next(data_iterator)
-    else:
-        data = None
-    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
-
-    # Unpack.
-    tokens_ = data_b['text'].long()
-    labels = tokens_[:, 1:].contiguous()
-    tokens = tokens_[:, :-1].contiguous()
-
-    # Get the masks and postition ids.
-    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
-        tokens,
-        tokenizer.eod,
-        args.reset_position_ids,
-        args.reset_attention_mask,
-        args.eod_mask_loss)
-
-    batch = {
-        'tokens': tokens,
-        'labels': labels,
-        'loss_mask': loss_mask,
-        'attention_mask': attention_mask,
-        'position_ids': position_ids
-    }
-    # slice batch along sequence dimension for context parallelism
-    batch = get_batch_on_this_cp_rank(batch)
-
-    return batch.values()
-
-
-def loss_func(loss_mask: Tensor, output_tensor: Tensor):
-    """Loss function.
-
-    Args:
-        loss_mask (Tensor): Used to mask out some portions of the loss
-        output_tensor (Tensor): The tensor with the losses
-    """
-    args = get_args()
-
-    losses = output_tensor.float()
-    loss_mask = loss_mask.view(-1).float()
-    if args.context_parallel_size > 1:
-        loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])
-        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
-        loss = loss[0] / loss[1]
-    else:
-        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
-
-    # Check individual rank losses are not NaN prior to DP all-reduce.
-    if args.check_for_nan_in_loss_and_grad:
-        global_rank = torch.distributed.get_rank()
-        assert not loss.isnan(), (
-            f'Rank {global_rank}: found NaN in local forward loss calculation. '
-            f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
-        )
-
-    # Reduce loss for logging.
-    averaged_loss = average_losses_across_data_parallel_group([loss])
-
-    return loss * args.context_parallel_size, {'lm loss': averaged_loss[0]}
-
-
-def forward_step(data_iterator, model: LLaMAModel):
-    """Forward training step.
-
-    Args:
-        data_iterator : Input data iterator
-        model (LLaMAModel): The LLaMA Model
-    """
-    args = get_args()
-    timers = get_timers()
-
-    # Get the batch.
-    timers('batch-generator', log_level=2).start()
-    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
-        data_iterator)
-    timers('batch-generator').stop()
-
-    output_tensor = model(tokens, position_ids, attention_mask,
-                          labels=labels)
-
-    return output_tensor, partial(loss_func, loss_mask)
-
-
-def is_dataset_built_on_rank():
-    return (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and mpu.get_tensor_model_parallel_rank() == 0
-
-
-def core_llama_dataset_config_from_args(args):
-    return GPTDatasetConfig(
-        is_built_on_rank=is_dataset_built_on_rank,
-        random_seed=args.seed,
-        sequence_length=args.seq_length,
-        blend=args.data_path,
-        blend_per_split=[args.train_data_path, args.valid_data_path, args.test_data_path],
-        split=args.split,
-        path_to_cache=args.data_cache_path,
-    )
-
-
-def train_valid_test_datasets_provider(train_val_test_num_samples):
-    """Build the train test and validation datasets.
-
-    Args:
-        train_val_test_num_samples : A list containing the number of samples in train test and validation.
-    """
-    args = get_args()
-
-    print_rank_0("> building train, validation, and test datasets for GPT ...")
-
-    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
-        GPTDataset,
-        train_val_test_num_samples,
-        core_llama_dataset_config_from_args(args)
-    ).build()
-
-    print_rank_0("> finished creating LLaMA datasets ...")
-
-    return train_ds, valid_ds, test_ds
-
-
-if __name__ == "__main__":
-
-    # Temporary for transition to core datasets
-    train_valid_test_datasets_provider.is_distributed = True
-
-    pretrain(train_valid_test_datasets_provider,
-             model_provider,
-             ModelType.encoder_or_decoder,
-             forward_step,
-             args_defaults={'tokenizer_type': 'Llama2Tokenizer'},
-             get_batch_fn=get_batch)
diff --git a/training/benchmarks/llama2_70B/megatron/run_pretraining.py b/training/benchmarks/llama2_70B/megatron/run_pretraining.py
index 0501e2319..5d7cea938 100644
--- a/training/benchmarks/llama2_70B/megatron/run_pretraining.py
+++ b/training/benchmarks/llama2_70B/megatron/run_pretraining.py
@@ -42,6 +42,7 @@ def parse_args():
     theoryflops = getattr(module, 'theoryflops')
     epochs = getattr(module, 'epochs')
     flashattn = getattr(module, 'flashattn')
+    recompute = getattr(module, 'recompute')
     tensor_parallel = getattr(module, 'tensor_parallel')
     pipeline_parallel = getattr(module, 'pipeline_parallel')
 
@@ -66,6 +67,7 @@ def parse_args():
     exec_cmd = exec_cmd + " " + str(gbs)
     exec_cmd = exec_cmd + " " + str(seqlength)
     exec_cmd = exec_cmd + " " + str(flashattn)
+    exec_cmd = exec_cmd + " " + str(recompute)
     exec_cmd = exec_cmd + " " + os.path.join(config_dir_path, "training_adapter.sh")
 
     with open(task_log_file, "w") as f:
@@ -87,4 +89,4 @@ def parse_args():
     chip_tps = whole_tps / (args.nproc_per_node * args.nnodes)
     print("System tokens per second: ", whole_tps)
     print("Tokens/p/s: ", chip_tps)
-    print("MFU: ", chip_tps * 7000000000.0 * 6 / theoryflops)
\ No newline at end of file
+    print("MFU: ", chip_tps * 70000000000.0 * 6 / theoryflops)
\ No newline at end of file
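Sanity check for the corrected constant above: Llama2-70B has roughly 70 billion (7.0e10) parameters, so achieved per-chip FLOPS is about 6 * parameters * tokens/p/s, and dividing by theoryflops from config_H800x4x8.py reproduces the MFU column of the README below to within rounding. The snippet is illustrative and not part of the patch:

    # Reproduce the MFU numbers with the fixed constant (7e9 -> 7e10 parameters).
    theoryflops = 989000000000000.0  # per-chip peak FLOPS, from config_H800x4x8.py

    def mfu(tokens_per_chip_per_second):
        # achieved FLOPS per chip ~= 6 * parameter_count * tokens processed per second
        return tokens_per_chip_per_second * 70000000000.0 * 6 / theoryflops

    print(f"{mfu(791.37):.1%}")  # 33.6%, amp TP4PP8DP1 row
    print(f"{mfu(253.61):.1%}")  # ~10.8%, fp32 recompute row (README lists 10.7%)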
index ba8f157a5..073708b43 100644
--- a/training/nvidia/docker_image/megatron/megatron_install.sh
+++ b/training/nvidia/docker_image/megatron/megatron_install.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
 # using github mirrors to avoid github TTL
-git clone -b kunlunxin_llama70B https://github.com/jamesruio/FlagScale.git
+git clone https://githubfast.com/FlagOpen/FlagScale
 echo 'export PYTHONPATH=$PYTHONPATH:/workspace/FlagScale' >> /root/.bashrc
 source /root/.bashrc
\ No newline at end of file
diff --git a/training/nvidia/llama2_70B-megatron/README.md b/training/nvidia/llama2_70B-megatron/README.md
index 4fc9965f3..185746635 100644
--- a/training/nvidia/llama2_70B-megatron/README.md
+++ b/training/nvidia/llama2_70B-megatron/README.md
@@ -40,7 +40,7 @@
 | 任务类别     | 自然语言理解              |                                    |
 | 模型         | llama2_70b                |                                    |
 | 数据集       | pile wikipedia            |                                    |
-| 数据精度     | amp                       |                                    |
+| 数据精度     | precision,见“性能指标”  | 可选fp32/amp/fp16/bf16             |
 | 超参修改     | parallel,见“性能指标”   | 格式为TPxPPyDPz,例如TP2PP1DP4     |
 | 超参修改     | fix_hp,见“性能指标”     | 跑满硬件设备评测吞吐量所需特殊超参 |
 | 硬件设备简称 | nvidia H800               |                                    |
@@ -50,7 +50,8 @@
 
 * 性能指标
 
-| 配置                | parallel  | fix_hp           | token/p/s | loss | mem   | MFU   |
-| ------------------- | ------    | ---------------- | ------    | ---- | ----- | ----- |
-| H800单机8卡(4x8)  | TP8PP4DP1 | /                | 641.93    | 5.7  | 62/80 | 27.2% |
-| H800单机8卡(4x8)  | TP4PP8DP1 | /                | 791.37    | 5.6  | 74/80 | 33.6% |
\ No newline at end of file
+| 配置               | precision | parallel  | fix_hp           | token/p/s | loss | mem   | MFU   |
+| ------------------ | --------- | --------- | ---------------- | --------- | ---- | ----- | ----- |
+| H800单机8卡(4x8) | fp32      | TP8PP4DP1 | recompute=True   | 253.61    | 0.94 | 77/80 | 10.7% |
+| H800单机8卡(4x8) | amp       | TP8PP4DP1 | /                | 641.93    | 5.7  | 62/80 | 27.2% |
+| H800单机8卡(4x8) | amp       | TP4PP8DP1 | /                | 791.37    | 5.6  | 74/80 | 33.6% |
diff --git a/training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py b/training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py
index 63044c545..fe8af8395 100644
--- a/training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py
+++ b/training/nvidia/llama2_70B-megatron/config/config_H800x4x8.py
@@ -5,5 +5,6 @@
 theoryflops = 989000000000000.0
 epochs = 1
 flashattn = True
+recompute = False
 tensor_parallel = 8
 pipeline_parallel = 4
\ No newline at end of file
diff --git a/training/run_benchmarks/config/cluster_conf.py b/training/run_benchmarks/config/cluster_conf.py
index 0c184df36..be628e197 100644
--- a/training/run_benchmarks/config/cluster_conf.py
+++ b/training/run_benchmarks/config/cluster_conf.py
@@ -1,7 +1,7 @@
 '''Cluster configs'''
 
 # Hosts to run the benchmark. Each item is an IP address or a hostname.
-HOSTS = ["192.2.32.13", "192.2.32.14", "192.2.32.2", "192.2.32.4"]
+HOSTS = ["10.1.2.2", "10.1.2.3", "10.1.2.4"]
 
 # Hosts port to run the tensorflow distribution_strategy = 'multi_worker_mirrored'
 HOSTS_PORTS = ["2222"]
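On the configuration side, recompute is the only new knob this patch introduces, and it defaults to False in config_H800x4x8.py above. A hedged sketch of the settings behind the new fp32 row in the README, i.e. the same values with recompute flipped on (the fp32 data type itself is presumably selected through the vendor's training_adapter.sh / MIXED_PRECISION_ARGS, which is outside this patch):

    # Relevant lines of config_H800x4x8.py (sketch, values from the diff above),
    # with recompute enabled to match the fix_hp=recompute=True fp32 row.
    theoryflops = 989000000000000.0
    epochs = 1
    flashattn = True
    recompute = True   # patch default is False; True makes megatron_main.sh add --recompute-activations
    tensor_parallel = 8
    pipeline_parallel = 4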